This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 3f2a5145a3567fa864fe0a4a8d72573666179321
Author: lihangyu <[email protected]>
AuthorDate: Fri Aug 23 10:22:24 2024 +0800

    [Fix](Variant) casting to decimal type may lose precision and scale in 
_normalize_predicate (#39650)
    
    Use TypeDescriptor instead of PrimitiveType so that the precision and scale of the cast target type are not lost.
---
 be/src/olap/iterators.h                            |  2 +-
 be/src/olap/rowset/rowset_reader_context.h         |  2 +-
 be/src/olap/rowset/segment_v2/segment_iterator.cpp |  2 +-
 be/src/olap/tablet_reader.cpp                      |  3 +-
 be/src/olap/tablet_reader.h                        |  2 +-
 be/src/pipeline/exec/scan_operator.cpp             | 38 +++++++++++-----------
 be/src/pipeline/exec/scan_operator.h               |  5 +--
 .../data/variant_p0/sql/implicit_cast.out          | 12 +++++++
 .../suites/variant_p0/sql/implicit_cast.sql        |  4 ++-
 9 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/be/src/olap/iterators.h b/be/src/olap/iterators.h
index cbf8f1eca65..dd1e9528fa3 100644
--- a/be/src/olap/iterators.h
+++ b/be/src/olap/iterators.h
@@ -117,7 +117,7 @@ public:
     Version version;
     int64_t tablet_id = 0;
     // slots that cast may be eliminated in storage layer
-    std::map<std::string, PrimitiveType> target_cast_type_for_variants;
+    std::map<std::string, TypeDescriptor> target_cast_type_for_variants;
     RowRanges row_ranges;
     size_t topn_limit = 0;
 };
diff --git a/be/src/olap/rowset/rowset_reader_context.h 
b/be/src/olap/rowset/rowset_reader_context.h
index 59abf85fb72..2de76df8040 100644
--- a/be/src/olap/rowset/rowset_reader_context.h
+++ b/be/src/olap/rowset/rowset_reader_context.h
@@ -81,7 +81,7 @@ struct RowsetReaderContext {
     const std::set<int32_t>* output_columns = nullptr;
     RowsetId rowset_id;
     // slots that cast may be eliminated in storage layer
-    std::map<std::string, PrimitiveType> target_cast_type_for_variants;
+    std::map<std::string, TypeDescriptor> target_cast_type_for_variants;
     int64_t ttl_seconds = 0;
     size_t topn_limit = 0;
 };
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp 
b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
index 66c7867991e..79ab794a9a0 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
@@ -1884,7 +1884,7 @@ bool 
SegmentIterator::_can_evaluated_by_vectorized(ColumnPredicate* predicate) {
     if (field_type == FieldType::OLAP_FIELD_TYPE_VARIANT) {
         // Use variant cast dst type
         field_type = TabletColumn::get_field_type_by_type(
-                
_opts.target_cast_type_for_variants[_schema->column(cid)->name()]);
+                
_opts.target_cast_type_for_variants[_schema->column(cid)->name()].type);
     }
     switch (predicate->type()) {
     case PredicateType::EQ:
diff --git a/be/src/olap/tablet_reader.cpp b/be/src/olap/tablet_reader.cpp
index 631a041379a..d158e6a5ac5 100644
--- a/be/src/olap/tablet_reader.cpp
+++ b/be/src/olap/tablet_reader.cpp
@@ -273,7 +273,8 @@ TabletColumn TabletReader::materialize_column(const 
TabletColumn& orig) {
     }
     TabletColumn column_with_cast_type = orig;
     auto cast_type = 
_reader_context.target_cast_type_for_variants.at(orig.name());
-    
column_with_cast_type.set_type(TabletColumn::get_field_type_by_type(cast_type));
+    FieldType filed_type = 
TabletColumn::get_field_type_by_type(cast_type.type);
+    column_with_cast_type.set_type(filed_type);
     return column_with_cast_type;
 }
 
diff --git a/be/src/olap/tablet_reader.h b/be/src/olap/tablet_reader.h
index 06c3daa653a..18ebb9653cc 100644
--- a/be/src/olap/tablet_reader.h
+++ b/be/src/olap/tablet_reader.h
@@ -136,7 +136,7 @@ public:
         std::vector<FunctionFilter> function_filters;
         std::vector<RowsetMetaSharedPtr> delete_predicates;
         // slots that cast may be eliminated in storage layer
-        std::map<std::string, PrimitiveType> target_cast_type_for_variants;
+        std::map<std::string, TypeDescriptor> target_cast_type_for_variants;
 
         std::vector<RowSetSplits> rs_splits;
         // For unique key table with merge-on-write
diff --git a/be/src/pipeline/exec/scan_operator.cpp 
b/be/src/pipeline/exec/scan_operator.cpp
index 73d237af7e9..42b23baa101 100644
--- a/be/src/pipeline/exec/scan_operator.cpp
+++ b/be/src/pipeline/exec/scan_operator.cpp
@@ -30,6 +30,7 @@
 #include "pipeline/exec/meta_scan_operator.h"
 #include "pipeline/exec/olap_scan_operator.h"
 #include "pipeline/exec/operator.h"
+#include "runtime/types.h"
 #include "util/runtime_profile.h"
 #include "vec/exprs/vcast_expr.h"
 #include "vec/exprs/vcompound_pred.h"
@@ -125,14 +126,14 @@ Status 
ScanLocalState<Derived>::_normalize_conjuncts(RuntimeState* state) {
     // The conjuncts is always on output tuple, so use _output_tuple_desc;
     std::vector<SlotDescriptor*> slots = p._output_tuple_desc->slots();
 
-    auto init_value_range = [&](SlotDescriptor* slot, PrimitiveType type) {
-        switch (type) {
-#define M(NAME)                                                                
          \
-    case TYPE_##NAME: {                                                        
          \
-        ColumnValueRange<TYPE_##NAME> range(slot->col_name(), 
slot->is_nullable(),       \
-                                            slot->type().precision, 
slot->type().scale); \
-        _slot_id_to_value_range[slot->id()] = std::pair {slot, range};         
          \
-        break;                                                                 
          \
+    auto init_value_range = [&](SlotDescriptor* slot, const TypeDescriptor& 
type_desc) {
+        switch (type_desc.type) {
+#define M(NAME)                                                                
    \
+    case TYPE_##NAME: {                                                        
    \
+        ColumnValueRange<TYPE_##NAME> range(slot->col_name(), 
slot->is_nullable(), \
+                                            type_desc.precision, 
type_desc.scale); \
+        _slot_id_to_value_range[slot->id()] = std::pair {slot, range};         
    \
+        break;                                                                 
    \
     }
 #define APPLY_FOR_PRIMITIVE_TYPE(M) \
     M(TINYINT)                      \
@@ -173,7 +174,7 @@ Status 
ScanLocalState<Derived>::_normalize_conjuncts(RuntimeState* state) {
                 continue;
             }
         }
-        init_value_range(slot, slot->type().type);
+        init_value_range(slot, slot->type());
     }
 
     get_cast_types_for_variants();
@@ -587,7 +588,7 @@ Status 
ScanLocalState<Derived>::_normalize_in_and_eq_predicate(vectorized::VExpr
                                                                
ColumnValueRange<T>& range,
                                                                PushDownType* 
pdt) {
     auto temp_range = ColumnValueRange<T>::create_empty_column_value_range(
-            slot->is_nullable(), slot->type().precision, slot->type().scale);
+            slot->is_nullable(), range.precision(), range.scale());
     // 1. Normalize in conjuncts like 'where col in (v1, v2, v3)'
     if (TExprNodeType::IN_PRED == expr->node_type()) {
         HybridSetBase::IteratorBase* iter = nullptr;
@@ -741,7 +742,7 @@ Status 
ScanLocalState<Derived>::_normalize_not_in_and_not_eq_predicate(
         ColumnValueRange<T>& range, PushDownType* pdt) {
     bool is_fixed_range = range.is_fixed_value_range();
     auto not_in_range = ColumnValueRange<T>::create_empty_column_value_range(
-            range.column_name(), slot->is_nullable(), slot->type().precision, 
slot->type().scale);
+            range.column_name(), slot->is_nullable(), range.precision(), 
range.scale());
     PushDownType temp_pdt = PushDownType::UNACCEPTABLE;
     // 1. Normalize in conjuncts like 'where col in (v1, v2, v3)'
     if (TExprNodeType::IN_PRED == expr->node_type()) {
@@ -924,14 +925,14 @@ Status 
ScanLocalState<Derived>::_normalize_is_null_predicate(vectorized::VExpr*
         if 
(reinterpret_cast<vectorized::VectorizedFnCall*>(expr)->fn().name.function_name 
==
             "is_null_pred") {
             auto temp_range = 
ColumnValueRange<T>::create_empty_column_value_range(
-                    slot->is_nullable(), slot->type().precision, 
slot->type().scale);
+                    slot->is_nullable(), range.precision(), range.scale());
             temp_range.set_contain_null(true);
             range.intersection(temp_range);
             *pdt = temp_pdt;
         } else if 
(reinterpret_cast<vectorized::VectorizedFnCall*>(expr)->fn().name.function_name 
==
                    "is_not_null_pred") {
             auto temp_range = 
ColumnValueRange<T>::create_empty_column_value_range(
-                    slot->is_nullable(), slot->type().precision, 
slot->type().scale);
+                    slot->is_nullable(), range.precision(), range.scale());
             temp_range.set_contain_null(false);
             range.intersection(temp_range);
             *pdt = temp_pdt;
@@ -1171,7 +1172,7 @@ Status 
ScanLocalState<Derived>::_normalize_match_predicate(vectorized::VExpr* ex
 
         // create empty range as temp range, temp range should do intersection 
on range
         auto temp_range = ColumnValueRange<T>::create_empty_column_value_range(
-                slot->is_nullable(), slot->type().precision, 
slot->type().scale);
+                slot->is_nullable(), range.precision(), range.scale());
         // Normalize match conjuncts like 'where col match value'
 
         auto match_checker = [](const std::string& fn_name) { return 
is_match_condition(fn_name); };
@@ -1334,7 +1335,7 @@ Status 
ScanLocalState<Derived>::_get_topn_filters(RuntimeState* state) {
 template <typename Derived>
 void ScanLocalState<Derived>::_filter_and_collect_cast_type_for_variant(
         const vectorized::VExpr* expr,
-        phmap::flat_hash_map<std::string, std::vector<PrimitiveType>>& 
colname_to_cast_types) {
+        std::unordered_map<std::string, std::vector<TypeDescriptor>>& 
colname_to_cast_types) {
     auto& p = _parent->cast<typename Derived::Parent>();
     const auto* cast_expr = dynamic_cast<const vectorized::VCastExpr*>(expr);
     if (cast_expr != nullptr) {
@@ -1347,10 +1348,9 @@ void 
ScanLocalState<Derived>::_filter_and_collect_cast_type_for_variant(
         }
         std::vector<SlotDescriptor*> slots = output_tuple_desc()->slots();
         SlotDescriptor* src_slot_desc = 
p._slot_id_to_slot_desc[src_slot->slot_id()];
-        PrimitiveType cast_dst_type =
-                
cast_expr->get_target_type()->get_type_as_type_descriptor().type;
+        TypeDescriptor type_desc = 
cast_expr->get_target_type()->get_type_as_type_descriptor();
         if (src_slot_desc->type().is_variant_type()) {
-            
colname_to_cast_types[src_slot_desc->col_name()].push_back(cast_dst_type);
+            
colname_to_cast_types[src_slot_desc->col_name()].push_back(type_desc);
         }
     }
     for (const auto& child : expr->children()) {
@@ -1360,7 +1360,7 @@ void 
ScanLocalState<Derived>::_filter_and_collect_cast_type_for_variant(
 
 template <typename Derived>
 void ScanLocalState<Derived>::get_cast_types_for_variants() {
-    phmap::flat_hash_map<std::string, std::vector<PrimitiveType>> 
colname_to_cast_types;
+    std::unordered_map<std::string, std::vector<TypeDescriptor>> 
colname_to_cast_types;
     for (auto it = _conjuncts.begin(); it != _conjuncts.end();) {
         auto& conjunct = *it;
         if (conjunct->root()) {
diff --git a/be/src/pipeline/exec/scan_operator.h 
b/be/src/pipeline/exec/scan_operator.h
index cbbeb75998d..43b5c7bb921 100644
--- a/be/src/pipeline/exec/scan_operator.h
+++ b/be/src/pipeline/exec/scan_operator.h
@@ -28,6 +28,7 @@
 #include "pipeline/common/runtime_filter_consumer.h"
 #include "pipeline/dependency.h"
 #include "runtime/descriptors.h"
+#include "runtime/types.h"
 #include "vec/exec/scan/vscan_node.h"
 #include "vec/exprs/vectorized_fn_call.h"
 #include "vec/exprs/vin_predicate.h"
@@ -340,7 +341,7 @@ protected:
     void get_cast_types_for_variants();
     void _filter_and_collect_cast_type_for_variant(
             const vectorized::VExpr* expr,
-            phmap::flat_hash_map<std::string, std::vector<PrimitiveType>>& 
colname_to_cast_types);
+            std::unordered_map<std::string, std::vector<TypeDescriptor>>& 
colname_to_cast_types);
 
     Status _get_topn_filters(RuntimeState* state);
 
@@ -357,7 +358,7 @@ protected:
     std::vector<FunctionFilter> _push_down_functions;
 
     // colname -> cast dst type
-    std::map<std::string, PrimitiveType> _cast_types_for_variants;
+    std::map<std::string, TypeDescriptor> _cast_types_for_variants;
 
     // slot id -> ColumnValueRange
     // Parsed from conjuncts
diff --git a/regression-test/data/variant_p0/sql/implicit_cast.out 
b/regression-test/data/variant_p0/sql/implicit_cast.out
index b0f5d96087b..2eefddc43e5 100644
--- a/regression-test/data/variant_p0/sql/implicit_cast.out
+++ b/regression-test/data/variant_p0/sql/implicit_cast.out
@@ -78,3 +78,15 @@ user
 user
 user
 
+-- !implicit_cast_14 --
+14690746673
+14690746676
+14690746679
+14690746680
+14690746681
+14690746684
+14690746685
+14690746687
+14690746688
+14690746689
+
diff --git a/regression-test/suites/variant_p0/sql/implicit_cast.sql 
b/regression-test/suites/variant_p0/sql/implicit_cast.sql
index 45acfd38513..416e2616ea0 100644
--- a/regression-test/suites/variant_p0/sql/implicit_cast.sql
+++ b/regression-test/suites/variant_p0/sql/implicit_cast.sql
@@ -12,4 +12,6 @@ SELECT v["payload"]["member"]["id"] FROM ghdata where 
v["payload"]["member"]["id
 select k, json_extract(v, '$.repo') from ghdata WHERE v["type"] = 'WatchEvent' 
 order by k limit 10;
 -- SELECT v["payload"]["member"]["id"], count() FROM ghdata where 
v["payload"]["member"]["id"] is not null group by v["payload"]["member"]["id"] 
order by 1, 2 desc LIMIT 10;
 select k, v["id"], v["type"], v["repo"]["name"] from ghdata WHERE v["type"] = 
'WatchEvent'  order by k limit 10;
-SELECT v["payload"]["pusher_type"] FROM ghdata where 
v["payload"]["pusher_type"] is not null ORDER BY k LIMIT 10;
\ No newline at end of file
+SELECT v["payload"]["pusher_type"] FROM ghdata where 
v["payload"]["pusher_type"] is not null ORDER BY k LIMIT 10;
+-- implicit cast to decimal type
+SELECT v["id"] FROM ghdata where v["id"] not in (7273, 10.118626, -69352) 
order by cast(v["id"] as decimal) limit 10;
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to