This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 93a934e7753 [Improve](map) support map cast with map literal and implicit nested scalar cast (#26126)
93a934e7753 is described below

commit 93a934e7753dbf54eed84c311fcd3027880cf71c
Author: amory <[email protected]>
AuthorDate: Thu Nov 2 22:56:42 2023 +0800

    [Improve](map) support map cast with map literal and implicit nested scalar cast (#26126)
---
 be/src/vec/functions/function_cast.h               | 58 ++++++++++++++++++---
 .../java/org/apache/doris/catalog/MapType.java     |  6 ++-
 .../main/java/org/apache/doris/catalog/Type.java   | 11 ++++
 .../java/org/apache/doris/analysis/CastExpr.java   |  2 +-
 .../main/java/org/apache/doris/analysis/Expr.java  |  7 +++
 .../cast_function/test_cast_map_function.out       | 48 +++++++++++++++++
 .../cast_function/test_cast_map_function.out       | 48 +++++++++++++++++
 .../cast_function/test_cast_map_function.groovy    | 60 ++++++++++++++++++++++
 .../cast_function/test_cast_map_function.groovy    | 60 ++++++++++++++++++++++
 9 files changed, 291 insertions(+), 9 deletions(-)

diff --git a/be/src/vec/functions/function_cast.h 
b/be/src/vec/functions/function_cast.h
index 8872f373b44..e2bc77f0cc3 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -51,6 +51,7 @@
 #include "vec/aggregate_functions/aggregate_function.h"
 #include "vec/columns/column.h"
 #include "vec/columns/column_array.h"
+#include "vec/columns/column_map.h"
 #include "vec/columns/column_nullable.h"
 #include "vec/columns/column_string.h"
 #include "vec/columns/column_struct.h"
@@ -1899,13 +1900,57 @@ private:
     }
 
     //TODO(Amory) . Need support more cast for key , value for map
-    WrapperType create_map_wrapper(const DataTypePtr& from_type, const 
DataTypeMap& to_type) const {
-        switch (from_type->get_type_id()) {
-        case TypeIndex::String:
+    WrapperType create_map_wrapper(FunctionContext* context, const 
DataTypePtr& from_type,
+                                   const DataTypeMap& to_type) const {
+        if (from_type->get_type_id() == TypeIndex::String) {
             return &ConvertImplGenericFromString::execute;
-        default:
-            return create_unsupport_wrapper(from_type->get_name(), 
to_type.get_name());
         }
+        auto from = check_and_get_data_type<DataTypeMap>(from_type.get());
+        if (!from) {
+            return create_unsupport_wrapper(
+                    fmt::format("CAST AS Map can only be performed between Map 
types or from "
+                                "String. from type: {}, to type: {}",
+                                from_type->get_name(), to_type.get_name()));
+        }
+        DataTypes from_kv_types;
+        DataTypes to_kv_types;
+        from_kv_types.reserve(2);
+        to_kv_types.reserve(2);
+        from_kv_types.push_back(from->get_key_type());
+        from_kv_types.push_back(from->get_value_type());
+        to_kv_types.push_back(to_type.get_key_type());
+        to_kv_types.push_back(to_type.get_value_type());
+
+        auto kv_wrappers = get_element_wrappers(context, from_kv_types, 
to_kv_types);
+        return [kv_wrappers, from_kv_types, to_kv_types](
+                       FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                       const size_t result, size_t /*input_rows_count*/) -> 
Status {
+            auto& from_column = 
block.get_by_position(arguments.front()).column;
+            auto from_col_map = 
check_and_get_column<ColumnMap>(from_column.get());
+            if (!from_col_map) {
+                return Status::RuntimeError("Illegal column {} for function 
CAST AS MAP",
+                                            from_column->get_name());
+            }
+
+            Columns converted_columns(2);
+            ColumnsWithTypeAndName columnsWithTypeAndName(2);
+            columnsWithTypeAndName[0] = {from_col_map->get_keys_ptr(), 
from_kv_types[0], ""};
+            columnsWithTypeAndName[1] = {from_col_map->get_values_ptr(), 
from_kv_types[1], ""};
+
+            for (size_t i = 0; i < 2; ++i) {
+                ColumnNumbers element_arguments {block.columns()};
+                block.insert(columnsWithTypeAndName[i]);
+                size_t element_result = block.columns();
+                block.insert({to_kv_types[i], ""});
+                RETURN_IF_ERROR(kv_wrappers[i](context, block, 
element_arguments, element_result,
+                                               
columnsWithTypeAndName[i].column->size()));
+                converted_columns[i] = 
block.get_by_position(element_result).column;
+            }
+
+            block.get_by_position(result).column = ColumnMap::create(
+                    converted_columns[0], converted_columns[1], 
from_col_map->get_offsets_ptr());
+            return Status::OK();
+        };
     }
 
     ElementWrappers get_element_wrappers(FunctionContext* context,
@@ -2166,7 +2211,8 @@ private:
             return create_struct_wrapper(context, from_type,
                                          static_cast<const 
DataTypeStruct&>(*to_type));
         case TypeIndex::Map:
-            return create_map_wrapper(from_type, static_cast<const 
DataTypeMap&>(*to_type));
+            return create_map_wrapper(context, from_type,
+                                      static_cast<const 
DataTypeMap&>(*to_type));
         case TypeIndex::HLL:
             return create_hll_wrapper(context, from_type,
                                       static_cast<const 
DataTypeHLL&>(*to_type));
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java 
b/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java
index e9efc83a8fc..691c4d7e04d 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java
@@ -192,8 +192,10 @@ public class MapType extends Type {
     }
 
     public static boolean canCastTo(MapType type, MapType targetType) {
-        return Type.canCastTo(type.getKeyType(), targetType.getKeyType())
-            && Type.canCastTo(type.getValueType(), targetType.getValueType());
+        return (targetType.getKeyType().isStringType() && 
type.getKeyType().isStringType()
+            || Type.canCastTo(type.getKeyType(), targetType.getKeyType()))
+            && (Type.canCastTo(type.getValueType(), targetType.getValueType())
+            || targetType.getValueType().isStringType() && 
type.getValueType().isStringType());
     }
 
     @Override
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java 
b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
index 36d35f9e2a0..0d2c96ca855 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
@@ -2169,6 +2169,17 @@ public abstract class Type {
                     return false;
                 }
                 return matchExactType(((ArrayType) type2).getItemType(), 
((ArrayType) type1).getItemType());
+            } else if (type2.isMapType()) {
+                // For types array, we also need to check contains null for 
case like
+                // cast(array<not_null(int)> as array<int>)
+                if (!((MapType) type2).getIsKeyContainsNull() == ((MapType) 
type1).getIsKeyContainsNull()) {
+                    return false;
+                }
+                if (!((MapType) type2).getIsValueContainsNull() == ((MapType) 
type1).getIsValueContainsNull()) {
+                    return false;
+                }
+                return matchExactType(((MapType) type2).getKeyType(), 
((MapType) type1).getKeyType())
+                    && matchExactType(((MapType) type2).getValueType(), 
((MapType) type1).getValueType());
             } else {
                 return true;
             }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
index bf5c75c58f0..590c87427a0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
@@ -268,7 +268,7 @@ public class CastExpr extends Expr {
         } else if (type.isMapType()) {
             fn = ScalarFunction.createBuiltin(getFnName(Type.MAP),
                     type, Function.NullableMode.ALWAYS_NULLABLE,
-                    Lists.newArrayList(Type.VARCHAR), false,
+                    
Lists.newArrayList(getActualArgTypes(collectChildReturnTypes())[0]), false,
                     "doris::CastFunctions::cast_to_map_val", null, null, true);
         } else if (type.isStructType()) {
             fn = ScalarFunction.createBuiltin(getFnName(Type.STRUCT),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index fe03afd02c3..1b32033ea88 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -28,6 +28,7 @@ import org.apache.doris.catalog.Env;
 import org.apache.doris.catalog.Function;
 import org.apache.doris.catalog.Function.NullableMode;
 import org.apache.doris.catalog.FunctionSet;
+import org.apache.doris.catalog.MapType;
 import org.apache.doris.catalog.MaterializedIndexMeta;
 import org.apache.doris.catalog.PrimitiveType;
 import org.apache.doris.catalog.ScalarFunction;
@@ -2516,6 +2517,8 @@ public abstract class Expr extends TreeNode<Expr> 
implements ParseNode, Cloneabl
             return getActualScalarType(originType);
         } else if (originType.getPrimitiveType() == PrimitiveType.ARRAY) {
             return getActualArrayType((ArrayType) originType);
+        } else if (originType.getPrimitiveType().isMapType()) {
+            return getActualMapType((MapType) originType);
         } else {
             return originType;
         }
@@ -2550,6 +2553,10 @@ public abstract class Expr extends TreeNode<Expr> 
implements ParseNode, Cloneabl
         return 
Arrays.stream(originType).map(this::getActualType).toArray(Type[]::new);
     }
 
+    private MapType getActualMapType(MapType originMapType) {
+        return new MapType(getActualType(originMapType.getKeyType()), 
getActualType(originMapType.getValueType()));
+    }
+
     private ArrayType getActualArrayType(ArrayType originArrayType) {
         return new ArrayType(getActualType(originArrayType.getItemType()));
     }
diff --git 
a/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out
 
b/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out
new file mode 100644
index 00000000000..c9da5a1c286
--- /dev/null
+++ 
b/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out
@@ -0,0 +1,48 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select --
+1      {"aa":1, "b":2, "1234567":77}
+2      {"b":12, "123":7777}
+
+-- !select --
+{}
+
+-- !select --
+{}
+
+-- !sql1 --
+\N
+
+-- !sql2 --
+{"":NULL}
+
+-- !sql3 --
+{"1":2}
+
+-- !sql4 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql5 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql6 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":97}
+
+-- !sql7 --
+{"aa":"1", "b":"2", "1234567":"77"}
+{"b":"12", "123":"7777"}
+
+-- !sql8 --
+{NULL:"1", NULL:"2", 1234567:"77"}
+{NULL:"12", 123:"7777"}
+
+-- !sql9 --
+{NULL:1, NULL:2, 1234567:77}
+{NULL:12, 123:7777}
+
+-- !sql10 --
+{NULL:NULL, NULL:NULL, 1234567:NULL}
+{NULL:NULL, 123:NULL}
+
diff --git 
a/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out
 
b/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out
new file mode 100644
index 00000000000..c9da5a1c286
--- /dev/null
+++ 
b/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out
@@ -0,0 +1,48 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select --
+1      {"aa":1, "b":2, "1234567":77}
+2      {"b":12, "123":7777}
+
+-- !select --
+{}
+
+-- !select --
+{}
+
+-- !sql1 --
+\N
+
+-- !sql2 --
+{"":NULL}
+
+-- !sql3 --
+{"1":2}
+
+-- !sql4 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql5 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql6 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":97}
+
+-- !sql7 --
+{"aa":"1", "b":"2", "1234567":"77"}
+{"b":"12", "123":"7777"}
+
+-- !sql8 --
+{NULL:"1", NULL:"2", 1234567:"77"}
+{NULL:"12", 123:"7777"}
+
+-- !sql9 --
+{NULL:1, NULL:2, 1234567:77}
+{NULL:12, 123:7777}
+
+-- !sql10 --
+{NULL:NULL, NULL:NULL, 1234567:NULL}
+{NULL:NULL, 123:NULL}
+
diff --git 
a/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy
 
b/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy
new file mode 100644
index 00000000000..e3a4ccdaecc
--- /dev/null
+++ 
b/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_cast_map_function", "query") {
+    sql """ set enable_nereids_planner = true; """
+    sql """ set enable_fallback_to_original_planner=false; """ 
+    def tableName = "tbl_test_cast_map_function_nereids"
+
+    sql """DROP TABLE IF EXISTS ${tableName}"""
+    sql """
+            CREATE TABLE IF NOT EXISTS ${tableName} (
+              `k1` int(11) NULL COMMENT "",
+              `k2` Map<char(7), int(11)> NOT NULL COMMENT "",
+            ) ENGINE=OLAP
+            DUPLICATE KEY(`k1`)
+            DISTRIBUTED BY HASH(`k1`) BUCKETS 1
+            PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "storage_format" = "V2"
+            )
+        """
+    // insert into with implicit cast
+    sql """ INSERT INTO ${tableName} VALUES(1, {"aa": 1, "b": 2, "1234567": 
77}) """
+    sql """ INSERT INTO ${tableName} VALUES(2, {"b":12, "123":7777}) """
+
+    qt_select """ select * from ${tableName} order by k1; """
+
+    qt_select " select cast({} as MAP<INT,INT>);"
+    qt_select " select cast(map() as MAP<INT,INT>); "
+    qt_sql1 "select cast(NULL as MAP<string,int>)"
+
+    // literal NONSTRICT_SUPERTYPE_OF cast
+    qt_sql2 "select cast({'':''} as MAP<String,INT>);"
+    qt_sql3 "select cast({1:2} as MAP<String,INT>);"
+
+    // select SUPERTYPE_OF cast
+    qt_sql4 "select cast(k2 as map<varchar, bigint>) from ${tableName} order 
by k1;"
+
+    // select NONSTRICT_SUPERTYPE_OF cast , this behavior is same with nested 
scala type
+    qt_sql5 "select cast(k2 as map<char(2), smallint>) from ${tableName} order 
by k1;"
+    qt_sql6 "select cast(k2 as map<char(1), tinyint>) from ${tableName} order 
by k1;"
+    qt_sql7 "select cast(k2 as map<char, string>) from ${tableName} order by 
k1;"
+    qt_sql8 "select cast(k2 as map<int, string>) from ${tableName} order by 
k1;"
+    qt_sql9 "select cast(k2 as map<largeint, decimal>) from ${tableName} order 
by k1;"
+    qt_sql10 "select cast(k2 as map<double, datetime>) from ${tableName} order 
by k1;"
+}
diff --git 
a/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy
 
b/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy
new file mode 100644
index 00000000000..021f8096b04
--- /dev/null
+++ 
b/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_cast_map_function", "query") {
+    sql """set enable_nereids_planner = false """
+    def tableName = "tbl_test_cast_map_function"
+    // array functions only supported in vectorized engine
+
+    sql """DROP TABLE IF EXISTS ${tableName}"""
+    sql """
+            CREATE TABLE IF NOT EXISTS ${tableName} (
+              `k1` int(11) NULL COMMENT "",
+              `k2` Map<char(7), int(11)> NOT NULL COMMENT "",
+            ) ENGINE=OLAP
+            DUPLICATE KEY(`k1`)
+            DISTRIBUTED BY HASH(`k1`) BUCKETS 1
+            PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "storage_format" = "V2"
+            )
+        """
+    // insert into with implicit cast
+    sql """ INSERT INTO ${tableName} VALUES(1, {"aa": 1, "b": 2, "1234567": 
77}) """
+    sql """ INSERT INTO ${tableName} VALUES(2, {"b":12, "123":7777}) """
+
+    qt_select """ select * from ${tableName} order by k1; """
+
+    qt_select " select cast({} as MAP<INT,INT>);"
+    qt_select " select cast(map() as MAP<INT,INT>); "
+    qt_sql1 "select cast(NULL as MAP<string,int>)"
+
+    // literal NONSTRICT_SUPERTYPE_OF cast
+    qt_sql2 "select cast({'':''} as MAP<String,INT>);"
+    qt_sql3 "select cast({1:2} as MAP<String,INT>);"
+
+    // select SUPERTYPE_OF cast
+    qt_sql4 "select cast(k2 as map<varchar, bigint>) from ${tableName} order 
by k1;"
+
+    // select NONSTRICT_SUPERTYPE_OF cast , this behavior is same with nested 
scala type
+    qt_sql5 "select cast(k2 as map<char(2), smallint>) from ${tableName} order 
by k1;"
+    qt_sql6 "select cast(k2 as map<char(1), tinyint>) from ${tableName} order 
by k1;"
+    qt_sql7 "select cast(k2 as map<char, string>) from ${tableName} order by 
k1;"
+    qt_sql8 "select cast(k2 as map<int, string>) from ${tableName} order by 
k1;"
+    qt_sql9 "select cast(k2 as map<largeint, decimal>) from ${tableName} order 
by k1;"
+    qt_sql10 "select cast(k2 as map<double, datetime>) from ${tableName} order 
by k1;"
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to