zclllyybb commented on code in PR #54036:
URL: https://github.com/apache/doris/pull/54036#discussion_r2337405266
##########
be/src/vec/functions/function_map.cpp:
##########
@@ -440,13 +491,409 @@ class FunctionStrToMap : public IFunction {
}
};
+class FunctionMapContainsEntry : public IFunction {
+public:
+ static constexpr auto name = "map_contains_entry";
+ static FunctionPtr create() { return
std::make_shared<FunctionMapContainsEntry>(); }
+
+ String get_name() const override { return name; }
+ size_t get_number_of_arguments() const override { return 3; }
+ bool use_default_implementation_for_nulls() const override { return false;
}
+
+ DataTypePtr get_return_type_impl(const DataTypes& arguments) const
override {
+ DataTypePtr datatype = arguments[0];
+ if (datatype->is_nullable()) {
+ datatype = assert_cast<const
DataTypeNullable*>(datatype.get())->get_nested_type();
+ }
+ DCHECK(datatype->get_primitive_type() == TYPE_MAP)
+ << "first argument for function: " << name << " should be
DataTypeMap";
+
+ if (arguments[0]->is_nullable()) {
+ return make_nullable(std::make_shared<DataTypeBool>());
+ } else {
+ return std::make_shared<DataTypeBool>();
+ }
+ }
+
+ Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ uint32_t result, size_t input_rows_count) const
override {
+ return _execute_type_check_and_dispatch(block, arguments, result);
+ }
+
+private:
+ // assume result_matches is initialized to all 1s
+ template <typename ColumnType>
+ void _execute_column_comparison(const IColumn& map_entry_column, const
UInt8* map_entry_nullmap,
+ const IColumn& search_column, const UInt8*
search_nullmap,
+ const ColumnArray::Offsets64& map_offsets,
+ const UInt8* map_row_nullmap, bool
search_is_const,
+ ColumnUInt8& result_matches) const {
+ auto& result_data = result_matches.get_data();
+ for (size_t row = 0; row < map_offsets.size(); ++row) {
+ if (map_row_nullmap && map_row_nullmap[row]) {
+ continue;
+ }
+ size_t map_start = row == 0 ? 0 : map_offsets[row - 1];
+ size_t map_end = map_offsets[row];
+ // const column always uses index 0
+ size_t search_idx = search_is_const ? 0 : row;
+ for (size_t i = map_start; i < map_end; ++i) {
+ result_data[i] &=
+ compare_values<ColumnType>(map_entry_column, i,
map_entry_nullmap,
+ search_column, search_idx,
search_nullmap)
+ ? 1
+ : 0;
+ }
+ }
+ }
+
+ // dispatch column comparison by type, map_entry_column is the column of
map's key or value, search_column is the column of search key or value
+ void _dispatch_column_comparison(PrimitiveType type, const IColumn&
map_entry_column,
+ const UInt8* map_entry_nullmap, const
IColumn& search_column,
+ const UInt8* search_nullmap,
+ const ColumnArray::Offsets64& map_offsets,
+ const UInt8* map_row_nullmap, bool
search_is_const,
+ ColumnUInt8& result_matches) const {
+ switch (type) {
+ case TYPE_BOOLEAN:
+ _execute_column_comparison<ColumnUInt8>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_TINYINT:
+ _execute_column_comparison<ColumnInt8>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_SMALLINT:
+ _execute_column_comparison<ColumnInt16>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_INT:
+ _execute_column_comparison<ColumnInt32>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_BIGINT:
+ _execute_column_comparison<ColumnInt64>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_LARGEINT:
+ _execute_column_comparison<ColumnInt128>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_FLOAT:
+ _execute_column_comparison<ColumnFloat32>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DOUBLE:
+ _execute_column_comparison<ColumnFloat64>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DECIMAL32:
+ _execute_column_comparison<ColumnDecimal32>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DECIMAL64:
+ _execute_column_comparison<ColumnDecimal64>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DECIMAL128I:
+ _execute_column_comparison<ColumnDecimal128V3>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DECIMALV2:
+ _execute_column_comparison<ColumnDecimal128V2>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DECIMAL256:
+ _execute_column_comparison<ColumnDecimal256>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_STRING:
+ case TYPE_CHAR:
+ case TYPE_VARCHAR:
+ _execute_column_comparison<ColumnString>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DATE:
+ _execute_column_comparison<ColumnDate>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DATETIME:
+ _execute_column_comparison<ColumnDateTime>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DATEV2:
+ _execute_column_comparison<ColumnDateV2>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_DATETIMEV2:
+ _execute_column_comparison<ColumnDateTimeV2>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_TIME:
+ _execute_column_comparison<ColumnTime>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_TIMEV2:
+ _execute_column_comparison<ColumnTimeV2>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_IPV4:
+ _execute_column_comparison<ColumnIPv4>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ case TYPE_IPV6:
+ _execute_column_comparison<ColumnIPv6>(
+ map_entry_column, map_entry_nullmap, search_column,
search_nullmap, map_offsets,
+ map_row_nullmap, search_is_const, result_matches);
+ break;
+ default:
+ // We have done type check before dispatching, so this should not
happen
+ DCHECK(false) << "Dispatching unsupported primitive type in " <<
get_name() << ": "
+ << static_cast<int>(type);
+ break;
+ }
+ }
+
+ // main loop function
+ ColumnPtr _execute_all_rows(const ColumnMap* map_column, const ColumnPtr&
map_row_nullmap_col,
+ const IColumn& key_column, const UInt8*
key_nullmap,
+ const IColumn& value_column, const UInt8*
value_nullmap,
+ PrimitiveType key_type, PrimitiveType
value_type, bool key_is_const,
+ bool value_is_const) const {
+ const auto& map_offsets = map_column->get_offsets();
+
+ // remove the nullable wrapper of map's key and value
+ const auto& map_keys_nullable =
+ reinterpret_cast<const
ColumnNullable&>(map_column->get_keys());
+ const IColumn* map_keys_column =
&map_keys_nullable.get_nested_column();
+ const auto& map_keys_nullmap =
map_keys_nullable.get_null_map_column().get_data().data();
+
+ const auto& map_values_nullable =
+ reinterpret_cast<const
ColumnNullable&>(map_column->get_values());
+ const IColumn* map_values_column =
&map_values_nullable.get_nested_column();
+ const auto& map_values_nullmap =
+ map_values_nullable.get_null_map_column().get_data().data();
+
+ auto result_column = ColumnUInt8::create(map_offsets.size(), 0);
+ auto& result_data = result_column->get_data();
+
+ const UInt8* map_row_nullmap = nullptr;
+ if (map_row_nullmap_col) {
+ map_row_nullmap =
+ assert_cast<const
ColumnUInt8&>(*map_row_nullmap_col).get_data().data();
+ }
+
+ auto matches = ColumnUInt8::create(map_keys_column->size(), 1);
+
+ // matches &= key_compare
+ _dispatch_column_comparison(key_type, *map_keys_column,
map_keys_nullmap, key_column,
+ key_nullmap, map_offsets, map_row_nullmap,
key_is_const,
+ *matches);
+
+ // matches &= value_compare
+ _dispatch_column_comparison(value_type, *map_values_column,
map_values_nullmap,
+ value_column, value_nullmap, map_offsets,
map_row_nullmap,
+ value_is_const, *matches);
+
+ // aggregate results by map boundaries
+ auto& matches_data = matches->get_data();
+ for (size_t row = 0; row < map_offsets.size(); ++row) {
+ if (map_row_nullmap && map_row_nullmap[row]) {
+ // result is null for this row
+ continue;
+ }
+
+ size_t map_start = row == 0 ? 0 : map_offsets[row - 1];
+ size_t map_end = map_offsets[row];
+
+ bool found = false;
+ for (size_t i = map_start; i < map_end && !found; ++i) {
+ if (matches_data[i]) {
+ found = true;
+ break;
+ }
+ }
+ result_data[row] = found;
+ }
+
+ if (map_row_nullmap_col) {
+ return ColumnNullable::create(std::move(result_column),
map_row_nullmap_col);
+ }
+ return result_column;
+ }
+
+ // type comparability check and dispatch
+ Status _execute_type_check_and_dispatch(Block& block, const ColumnNumbers&
arguments,
+ uint32_t result) const {
+ // get type information
+ auto map_type =
remove_nullable(block.get_by_position(arguments[0]).type);
+ const auto* map_datatype = assert_cast<const
DataTypeMap*>(map_type.get());
+ auto map_key_type = remove_nullable(map_datatype->get_key_type());
+ auto map_value_type = remove_nullable(map_datatype->get_value_type());
+ auto search_key_type =
remove_nullable(block.get_by_position(arguments[1]).type);
+ auto search_value_type =
remove_nullable(block.get_by_position(arguments[2]).type);
+
+ bool key_types_comparable =
type_comparable(map_key_type->get_primitive_type(),
Review Comment:
In the FE we use `AnyDataType` with the same slot, or `FollowToAnyDataType`.
Both of them infer the arguments to the same type. So I think
`map_key_type->get_primitive_type()` and
`search_key_type->get_primitive_type()` should always be the same type here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]