This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 015371a  [fix](grouping-set) Fix the bug of grouping set core in both vec and non vec query engine (#7800)
015371a is described below

commit 015371ac7267aef141a5f46eb2d4ddca59873485
Author: HappenLee <[email protected]>
AuthorDate: Wed Jan 26 16:15:30 2022 +0800

    [fix](grouping-set) Fix the bug of grouping set core in both vec and non vec query engine (#7800)
---
 be/src/vec/common/columns_hashing.h                |  2 +-
 be/src/vec/exec/vaggregation_node.cpp              | 28 ++++++++++++++--
 be/src/vec/exec/vaggregation_node.h                |  5 +++
 be/src/vec/exec/vrepeat_node.cpp                   | 19 ++++-------
 be/src/vec/functions/function_grouping.h           | 38 +++++-----------------
 .../apache/doris/analysis/AggregateInfoBase.java   |  9 +++++
 .../java/org/apache/doris/analysis/Analyzer.java   |  5 +--
 .../java/org/apache/doris/planner/RepeatNode.java  |  3 +-
 .../apache/doris/planner/SingleNodePlanner.java    | 10 ------
 9 files changed, 59 insertions(+), 60 deletions(-)

diff --git a/be/src/vec/common/columns_hashing.h 
b/be/src/vec/common/columns_hashing.h
index 26bc0d5..75e60b7 100644
--- a/be/src/vec/common/columns_hashing.h
+++ b/be/src/vec/common/columns_hashing.h
@@ -196,7 +196,7 @@ struct HashMethodSingleLowNullableColumn : public 
SingleColumnMethod {
 
     ColumnRawPtrs key_columns;
 
-    static const ColumnRawPtrs get_nested_column(const IColumn *col) {
+    static const ColumnRawPtrs get_nested_column(const IColumn* col) {
         auto* nullable = check_and_get_column<ColumnNullable>(*col);
         DCHECK(nullable != nullptr);
         const auto nested_col = nullable->get_nested_column_ptr().get();
diff --git a/be/src/vec/exec/vaggregation_node.cpp 
b/be/src/vec/exec/vaggregation_node.cpp
index ed8a1ba..1108a09 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -219,6 +219,14 @@ Status AggregationNode::prepare(RuntimeState* state) {
     _mem_pool = std::make_unique<MemPool>(mem_tracker().get());
 
     int j = _probe_expr_ctxs.size();
+    for (int i = 0; i < j; ++i) {
+        auto nullable_output = _output_tuple_desc->slots()[i]->is_nullable();
+        auto nullable_input = _probe_expr_ctxs[i]->root()->is_nullable();
+        if (nullable_output != nullable_input) {
+            DCHECK(nullable_output);
+            _make_nullable_keys.emplace_back(i);
+        }
+    }
     for (int i = 0; i < _aggregate_evaluators.size(); ++i, ++j) {
         SlotDescriptor* intermediate_slot_desc = 
_intermediate_tuple_desc->slots()[j];
         SlotDescriptor* output_slot_desc = _output_tuple_desc->slots()[j];
@@ -377,9 +385,11 @@ Status AggregationNode::get_next(RuntimeState* state, 
Block* block, bool* eos) {
         }
         // pre stream agg need use _num_row_return to decide whether to do pre 
stream agg
         _num_rows_returned += block->rows();
-        if (*eos) COUNTER_SET(_rows_returned_counter, _num_rows_returned);
+        _make_nullable_output_key(block);
+        COUNTER_SET(_rows_returned_counter, _num_rows_returned);
     } else {
         RETURN_IF_ERROR(_executor.get_result(state, block, eos));
+        _make_nullable_output_key(block);
         // dispose the having clause, should not be execute in prestreaming agg
         RETURN_IF_ERROR(VExprContext::filter_block(_vconjunct_ctx_ptr, block, 
block->columns()));
         reached_limit(block, eos);
@@ -556,6 +566,17 @@ void AggregationNode::_close_without_key() {
     release_tracker();
 }
 
+void AggregationNode::_make_nullable_output_key(Block* block) {
+    if (block->rows() != 0) {
+        for (auto cid : _make_nullable_keys) {
+            block->get_by_position(cid).column =
+                    make_nullable(block->get_by_position(cid).column);
+            block->get_by_position(cid).type =
+                    make_nullable(block->get_by_position(cid).type);
+        }
+    }
+}
+
 bool AggregationNode::_should_expand_preagg_hash_tables() {
     if (!_should_expand_hash_table) return false;
 
@@ -707,7 +728,8 @@ Status 
AggregationNode::_pre_agg_with_serialized_key(doris::vectorized::Block* i
                             for (int i = 0; i < key_size; ++i) {
                                 columns_with_schema.emplace_back(
                                         key_columns[i]->clone_resized(rows),
-                                        
_probe_expr_ctxs[i]->root()->data_type(), "");
+                                        
_probe_expr_ctxs[i]->root()->data_type(),
+                                        
_probe_expr_ctxs[i]->root()->expr_name());
                             }
                             for (int i = 0; i < value_columns.size(); ++i) {
                                 
columns_with_schema.emplace_back(std::move(value_columns[i]),
@@ -979,7 +1001,7 @@ Status 
AggregationNode::_serialize_with_serialized_key_result(RuntimeState* stat
         ColumnsWithTypeAndName columns_with_schema;
         for (int i = 0; i < key_size; ++i) {
             columns_with_schema.emplace_back(std::move(key_columns[i]),
-                                             
_probe_expr_ctxs[i]->root()->data_type(), "");
+                                             
_probe_expr_ctxs[i]->root()->data_type(), 
_probe_expr_ctxs[i]->root()->expr_name());
         }
         for (int i = 0; i < agg_size; ++i) {
             columns_with_schema.emplace_back(std::move(value_columns[i]), 
value_data_types[i], "");
diff --git a/be/src/vec/exec/vaggregation_node.h 
b/be/src/vec/exec/vaggregation_node.h
index 45f1d59..1864d31 100644
--- a/be/src/vec/exec/vaggregation_node.h
+++ b/be/src/vec/exec/vaggregation_node.h
@@ -388,6 +388,9 @@ public:
 private:
     // group by k1,k2
     std::vector<VExprContext*> _probe_expr_ctxs;
+    // Indexes of group-by keys whose output slot is nullable while the input expr is not
+    // (e.g. after a left/full join); these key columns are converted to nullable on output.
+    std::vector<size_t> _make_nullable_keys;
     std::vector<size_t> _probe_key_sz;
 
     std::vector<AggFnEvaluator*> _aggregate_evaluators;
@@ -433,6 +436,8 @@ private:
     /// the preagg should pass through any rows it can't fit in its tables.
     bool _should_expand_preagg_hash_tables();
 
+    void _make_nullable_output_key(Block* block);
+
     Status _create_agg_status(AggregateDataPtr data);
     Status _destory_agg_status(AggregateDataPtr data);
 
diff --git a/be/src/vec/exec/vrepeat_node.cpp b/be/src/vec/exec/vrepeat_node.cpp
index 287aa6e..f27b4f5 100644
--- a/be/src/vec/exec/vrepeat_node.cpp
+++ b/be/src/vec/exec/vrepeat_node.cpp
@@ -131,7 +131,7 @@ Status VRepeatNode::get_repeated_block(Block* child_block, 
int repeat_id_idx, Bl
         cur_col++;
     }
 
-    // Fill grouping ID to tuple
+    // Fill grouping ID to block
     for (auto slot_idx = 0; slot_idx < _grouping_list.size(); slot_idx++) {
         DCHECK_LT(slot_idx, _virtual_tuple_desc->slots().size());
         const SlotDescriptor* _virtual_slot_desc = 
_virtual_tuple_desc->slots()[slot_idx];
@@ -139,21 +139,13 @@ Status VRepeatNode::get_repeated_block(Block* 
child_block, int repeat_id_idx, Bl
         DCHECK_EQ(_virtual_slot_desc->col_name(), 
_output_slots[cur_col]->col_name());
         int64_t val = _grouping_list[slot_idx][repeat_id_idx];
         auto* column_ptr = columns[cur_col].get();
-        if (_output_slots[cur_col]->is_nullable()) {
-            auto* nullable_column = reinterpret_cast<ColumnNullable 
*>(columns[cur_col].get());
-            auto& null_map = nullable_column->get_null_map_data();
-            column_ptr = &nullable_column->get_nested_column();
-
-            for (size_t i = 0; i < child_block->rows(); ++i) {
-                null_map.push_back(0);
-            }
-        }
+        DCHECK(!_output_slots[cur_col]->is_nullable());
 
         auto* col = assert_cast<ColumnVector<Int64> *>(column_ptr);
         for (size_t i = 0; i < child_block->rows(); ++i) {
             col->insert_value(val);
         }
-        cur_col ++;
+        cur_col++;
     }
 
     DCHECK_EQ(cur_col, column_size);
@@ -194,9 +186,10 @@ Status VRepeatNode::get_next(RuntimeState* state, Block* 
block, bool* eos) {
             return Status::OK();
         }
 
-        RETURN_IF_ERROR(child(0)->get_next(state, _child_block.get(), 
&_child_eos));
+        while (_child_block->rows() == 0 && ! _child_eos)
+            RETURN_IF_ERROR(child(0)->get_next(state, _child_block.get(), 
&_child_eos));
 
-        if (_child_block->rows() == 0) {
+        if (_child_eos and _child_block->rows() == 0) {
             *eos = true;
             return Status::OK();
         }
diff --git a/be/src/vec/functions/function_grouping.h 
b/be/src/vec/functions/function_grouping.h
index 23875af..f8ea725 100644
--- a/be/src/vec/functions/function_grouping.h
+++ b/be/src/vec/functions/function_grouping.h
@@ -37,6 +37,14 @@ public:
     DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) 
const override {
         return std::make_shared<DataTypeInt64>();
     }
+
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) override {
+        const ColumnWithTypeAndName& src_column = 
block.get_by_position(arguments[0]);
+        DCHECK(src_column.column->size() == input_rows_count);
+        block.get_by_position(result).column = src_column.column;
+        return Status::OK();
+    }
 };
 
 class FunctionGrouping : public FunctionGroupingBase {
@@ -46,21 +54,6 @@ public:
     static FunctionPtr create() { return std::make_shared<FunctionGrouping>(); 
}
 
     String get_name() const override { return name; }
-
-    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
-                        size_t result, size_t input_rows_count) override {
-        const ColumnWithTypeAndName& src_column = 
block.get_by_position(arguments[0]);
-        const ColumnWithTypeAndName& rel_column = 
block.get_by_position(result);
-        if (!src_column.column)
-            return Status::InternalError("Illegal column " + src_column.name + 
" of first argument of function " + name);
-
-        DCHECK(src_column.type->is_nullable() == true);
-        MutableColumnPtr res_column = rel_column.type->create_column();
-        auto* src_nullable_column = reinterpret_cast<ColumnNullable 
*>(const_cast<IColumn *>(src_column.column.get()));
-        
res_column->insert_range_from(*src_nullable_column->get_nested_column_ptr().get(),
 0, src_column.column->size());
-        block.get_by_position(result).column = std::move(res_column);
-        return Status::OK();
-    }
 };
 
 class FunctionGroupingId : public FunctionGroupingBase {
@@ -70,21 +63,6 @@ public:
     static FunctionPtr create() { return 
std::make_shared<FunctionGroupingId>(); }
 
     String get_name() const override { return name; }
-
-    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
-                        size_t result, size_t input_rows_count) override {
-        const ColumnWithTypeAndName& src_column = 
block.get_by_position(arguments[0]);
-        const ColumnWithTypeAndName& rel_column = 
block.get_by_position(result);
-        if (!src_column.column)
-            return Status::InternalError("Illegal column " + src_column.name + 
" of first argument of function " + name);
-
-        DCHECK(src_column.type->is_nullable() == true);
-        MutableColumnPtr res_column = rel_column.type->create_column();
-        auto* src_nullable_column = reinterpret_cast<ColumnNullable 
*>(const_cast<IColumn *>(src_column.column.get()));
-        
res_column->insert_range_from(*src_nullable_column->get_nested_column_ptr().get(),
 0, src_column.column->size());
-        block.get_by_position(result).column = std::move(res_column);
-        return Status::OK();
-    }
 };
 }
 #endif //DORIS_FUNCTION_GROUPING_H
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java
index 04e3244..ed016ef 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java
@@ -125,10 +125,19 @@ public abstract class AggregateInfoBase {
         exprs.addAll(aggregateExprs_);
 
         int aggregateExprStartIndex = groupingExprs_.size();
+        // If this agg is a grouping set (its last grouping expr is a VirtualSlotRef), every
+        // grouping expr except that last one must have its slot set to nullable.
+        boolean isGroupingSet = !groupingExprs_.isEmpty() &&
+                groupingExprs_.get(groupingExprs_.size() - 1) instanceof 
VirtualSlotRef;
+
         for (int i = 0; i < exprs.size(); ++i) {
             Expr expr = exprs.get(i);
             SlotDescriptor slotDesc = analyzer.addSlotDescriptor(result);
             slotDesc.initFromExpr(expr);
+            // For grouping sets, force grouping-key slots nullable, but leave the grouping id slot unchanged
+            if (isGroupingSet && i < aggregateExprStartIndex - 1 && !(expr 
instanceof VirtualSlotRef)) {
+                slotDesc.setIsNullable(true);
+            }
             if (i < aggregateExprStartIndex) {
                 // register equivalence between grouping slot and grouping 
expr;
                 // do this only when the grouping expr isn't a constant, 
otherwise
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java
index b59c22c..95d3b45 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java
@@ -707,7 +707,8 @@ public class Analyzer {
 
     /**
      * Register a virtual column, and it is not a real column exist in table,
-     * so it does not need to resolve.
+     * so it does not need to resolve. Virtual slots are currently only used by grouping
+     * sets to generate the grouping id, so they should never be nullable.
      */
     public SlotDescriptor registerVirtualColumnRef(String colName, Type type, 
TupleDescriptor tupleDescriptor)
             throws AnalysisException {
@@ -722,7 +723,7 @@ public class Analyzer {
         result = addSlotDescriptor(tupleDescriptor);
         Column col = new Column(colName, type);
         result.setColumn(col);
-        result.setIsNullable(true);
+        result.setIsNullable(col.isAllowNull());
         slotRefMap.put(key, result);
         return result;
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java
index ba5c232..b70b374 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java
@@ -28,6 +28,7 @@ import org.apache.doris.analysis.SlotId;
 import org.apache.doris.analysis.SlotRef;
 import org.apache.doris.analysis.TupleDescriptor;
 import org.apache.doris.analysis.TupleId;
+import org.apache.doris.analysis.VirtualSlotRef;
 import org.apache.doris.common.UserException;
 import org.apache.doris.thrift.TExplainLevel;
 import org.apache.doris.thrift.TPlanNode;
@@ -132,7 +133,7 @@ public class RepeatNode extends PlanNode {
         outputTupleDesc = groupingInfo.getVirtualTuple();
         //set aggregate nullable
         for (Expr slot : groupByClause.getGroupingExprs()) {
-            if (slot instanceof SlotRef) {
+            if (slot instanceof SlotRef && !(slot instanceof VirtualSlotRef)) {
                 ((SlotRef) slot).getDesc().setIsNullable(true);
             }
         }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java
index e76253a..f62f606 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java
@@ -1070,16 +1070,6 @@ public class SingleNodePlanner {
                 && groupingInfo != null);
         root = new RepeatNode(ctx_.getNextNodeId(), root, groupingInfo, 
groupByClause);
         root.init(analyzer);
-        // set agg outtuple nullable
-        AggregateInfo aggInfo = selectStmt.getAggInfo();
-        TupleId aggOutTupleId = aggInfo.getOutputTupleId();
-        TupleDescriptor aggOutTupleDescriptor = 
analyzer.getDescTbl().getTupleDesc(aggOutTupleId);
-        int aggregateExprStartIndex = groupByClause.getGroupingExprs().size();
-        for (int i = 0; i < aggregateExprStartIndex; ++i) {
-            SlotDescriptor slot = aggOutTupleDescriptor.getSlots().get(i);
-            if (!slot.getIsNullable())
-                slot.setIsNullable(true);
-        }
         return root;
     }
 

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to