This is an automated email from the ASF dual-hosted git repository.

maxyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/cloudberry.git

commit 609d876a5b843097d0fb97f1b0731f74dfca1a43
Author: Jingyu Wang <[email protected]>
AuthorDate: Wed Nov 30 14:59:02 2022 -0800

    Fix incorrect plan / output in multi stage agg
    
    Issue: Multi-stage DQA with ride along aggregation (1) failed to generate a
    plan, or, (2) output NaN/garbage results
    
    Offending query: select count(distinct a), sum(b) from foo;
    
    Root cause: AggSplit enumerator is used to describe aggregation stages. In
    Postgres, Agg node aggsplit is derived from Aggref aggsplit. It supports
    simple (non-split), initial, and final aggregation stages. In Greenplum's
    attempt to support multi-stage DQA, we added support for an intermediate
    aggregation stage. Correspondingly, we removed the assertion of equality
    between Agg aggsplit and Aggref aggsplit.
    
    In the example above, count(distinct a) is a simple aggregation. The master
    handles aggregation, and the segments handle deduplication. sum(b) is a
    split aggregation, where the master handles the final aggregation, and the
    segments handle the initial and intermediate aggregations.
    
    Under the same Agg node, multiple Aggrefs are allowed to have different
    aggsplit values, indicating different stages for different aggregations.
    However, we failed to consider that when deriving the Agg node aggsplit
    from its children Aggref aggsplit. We mistakenly assigned the first Aggref
    aggsplit value to its parent Agg node. To correct this mistake, we iterate
    through all children Aggref aggsplit, and use their bitwise OR result as
    the Agg node aggsplit.
    
    Example:
    Agg node has two Aggref children
    aggref1 -> aggsplit = 0 (simple aggregation)
    aggref2 -> aggsplit = 9 (final aggregation)
    Agg -> aggsplit = 0 | 9 = 9 (final aggregation)
    
    Implementation:
    [node, CTranslatorDXLToScalar] -- Correct typo
    [CTranslatorDXLToPlStmt] -- Iterate through all children Aggrefs. Set Agg
    node aggsplit as the bitwise OR of Aggref aggsplit.
    [regress] -- Add test coverage for multi-stage DQA with ride along
    aggregation
    
    Note:
    This change tackles the issue from the ORCA side. It complements the work of
    PR #14577 "Fix crash of AggNode in executor casued by ORCA plan", which
    mends the problem from the executor side.
---
 src/backend/executor/nodeAgg.c                     |   2 +-
 .../gpopt/translate/CTranslatorDXLToPlStmt.cpp     |  29 ++---
 .../gpopt/translate/CTranslatorDXLToScalar.cpp     |   2 +-
 src/include/nodes/nodes.h                          |   2 +-
 src/test/regress/expected/gp_dqa.out               | 102 +++++++++++++++++
 src/test/regress/expected/gp_dqa_optimizer.out     | 124 ++++++++++++++++++++-
 src/test/regress/sql/gp_dqa.sql                    |  38 +++++++
 7 files changed, 275 insertions(+), 24 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 790821324a..f44186d72e 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -4233,7 +4233,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
         * transfn and transfn_oid fields of pertrans refer to the combine
         * function rather than the transition function.
         */
-       if (DO_AGGSPLIT_COMBINE(aggref->aggsplit))
+       if (DO_AGGSPLIT_COMBINE(pertrans->aggref->aggsplit))
        {
                Expr       *combinefnexpr;
                size_t          numTransArgs;
diff --git a/src/backend/gpopt/translate/CTranslatorDXLToPlStmt.cpp 
b/src/backend/gpopt/translate/CTranslatorDXLToPlStmt.cpp
index 8a40e924c0..c94e204021 100644
--- a/src/backend/gpopt/translate/CTranslatorDXLToPlStmt.cpp
+++ b/src/backend/gpopt/translate/CTranslatorDXLToPlStmt.cpp
@@ -2702,32 +2702,23 @@ CTranslatorDXLToPlStmt::TranslateDXLAgg(
         */
        // Set the aggsplit for the agg node
        ListCell *lc;
-       int idx = 0;
-       ForEach (lc, plan->targetlist)
+       INT aggsplit = 0;
+       foreach (lc, plan->targetlist)
        {
                TargetEntry *te = (TargetEntry *) lfirst(lc);
                if (IsA(te->expr, Aggref))
                {
                        Aggref *aggref = (Aggref *) te->expr;
-                       // initialize the aggsplit once
-                       if (idx == 0)
-                               agg->aggsplit = aggref->aggsplit;
-                       aggref->aggno = idx;
-                       aggref->aggtransno = idx;
-                       idx++;
-               }
-       }
-       ForEach (lc, plan->qual)
-       {
-               Expr *expr = (Expr *) lfirst(lc);
-               if (IsA(expr, Aggref))
-               {
-                       Aggref *aggref = (Aggref *) expr;
-                       aggref->aggno = idx;
-                       aggref->aggtransno = idx;
-                       idx++;
+
+                       aggsplit |= aggref->aggsplit;
+
+                       if (AGGSPLIT_INTERMEDIATE == aggsplit)
+                       {
+                               break;
+                       }
                }
        }
+       agg->aggsplit = (AggSplit) aggsplit;
 
        plan->lefttree = child_plan;
 
diff --git a/src/backend/gpopt/translate/CTranslatorDXLToScalar.cpp 
b/src/backend/gpopt/translate/CTranslatorDXLToScalar.cpp
index de79f591cf..0160180127 100644
--- a/src/backend/gpopt/translate/CTranslatorDXLToScalar.cpp
+++ b/src/backend/gpopt/translate/CTranslatorDXLToScalar.cpp
@@ -581,7 +581,7 @@ CTranslatorDXLToScalar::TranslateDXLScalarAggrefToScalar(
                        aggref->aggsplit = AGGSPLIT_INITIAL_SERIAL;
                        break;
                case EdxlaggstageIntermediate:
-                       aggref->aggsplit = AGGSPLIT_INTERNMEDIATE;
+                       aggref->aggsplit = AGGSPLIT_INTERMEDIATE;
                        break;
                case EdxlaggstageFinal:
                        aggref->aggsplit = AGGSPLIT_FINAL_DESERIAL;
diff --git a/src/include/nodes/nodes.h b/src/include/nodes/nodes.h
index a591ae5b1d..fa15896cd4 100644
--- a/src/include/nodes/nodes.h
+++ b/src/include/nodes/nodes.h
@@ -988,7 +988,7 @@ typedef enum AggSplit
         */
        AGGSPLIT_DEDUPLICATED = AGGSPLITOP_DEDUPLICATED,
 
-       AGGSPLIT_INTERNMEDIATE = AGGSPLITOP_SKIPFINAL | AGGSPLITOP_SERIALIZE | 
AGGSPLITOP_COMBINE | AGGSPLITOP_DESERIALIZE,
+       AGGSPLIT_INTERMEDIATE = AGGSPLITOP_SKIPFINAL | AGGSPLITOP_SERIALIZE | 
AGGSPLITOP_COMBINE | AGGSPLITOP_DESERIALIZE,
 } AggSplit;
 
 /* Test whether an AggSplit value selects each primitive option: */
diff --git a/src/test/regress/expected/gp_dqa.out 
b/src/test/regress/expected/gp_dqa.out
index 5ed20b2f3e..09c8aed08b 100644
--- a/src/test/regress/expected/gp_dqa.out
+++ b/src/test/regress/expected/gp_dqa.out
@@ -2473,6 +2473,108 @@ select count(distinct b), sum(c) from multiagg2;
 
 drop table multiagg1;
 drop table multiagg2;
+-- Support Multi-stage DQA with ride along aggregation in ORCA
+-- Historically, Agg aggsplit is identically equal to Aggref aggsplit
+-- In ORCA's attempt to support intermediate aggregation
+-- The two are allowed to differ
+-- Now Agg aggsplit is derived as bitwise OR of its children Aggref aggsplit
+-- The plan is to eventually make Agg aggsplit a dummy
+-- And use Aggref aggsplit to build trans/combine functions
+set optimizer_force_multistage_agg=on;
+create table num_table(id int, a bigint, b int, c numeric);
+NOTICE:  Table doesn't have 'DISTRIBUTED BY' clause -- Using column named 'id' 
as the Greenplum Database data distribution key for this table.
+HINT:  The 'DISTRIBUTED BY' clause determines the distribution of data. Make 
sure column(s) chosen are the optimal data distribution key to minimize skew.
+insert into num_table values(1,1,1,1),(2,2,2,2),(3,3,3,3);
+-- count(distinct a) is a simple aggregation
+-- sum(b) is a split aggregation
+-- Before the fix, in the final aggregation of sum(b)
+-- the executor mistakenly built a trans func instead of a combine func
+-- The trans func building process errored out due to mismatch between
+-- the input type (int) and trans type (bigint), and caused missing plan
+explain select count(distinct a), sum(b) from num_table;
+                                                 QUERY PLAN                    
                              
+-------------------------------------------------------------------------------------------------------------
+ Finalize Aggregate  (cost=14212.19..14212.20 rows=1 width=16)
+   ->  Gather Motion 3:1  (slice1; segments: 3)  (cost=14210.17..14212.18 
rows=3 width=16)
+         ->  Partial Aggregate  (cost=14210.17..14210.18 rows=1 width=16)
+               ->  Redistribute Motion 3:3  (slice2; segments: 3)  
(cost=0.00..14140.33 rows=13967 width=12)
+                     Hash Key: a
+                     ->  Seq Scan on num_table  (cost=0.00..173.67 rows=13967 
width=12)
+ Optimizer: Postgres query optimizer
+(7 rows)
+
+select count(distinct a), sum(b) from num_table;
+ count | sum 
+-------+-----
+     3 |   6
+(1 row)
+
+explain select count(distinct a), sum(b) from num_table group by id;
+                                             QUERY PLAN                        
                     
+----------------------------------------------------------------------------------------------------
+ Gather Motion 3:1  (slice1; segments: 3)  
(cost=10000001135.25..10000001944.92 rows=1000 width=20)
+   ->  GroupAggregate  (cost=10000001135.25..10000001278.25 rows=333 width=20)
+         Group Key: id
+         ->  Sort  (cost=1135.25..1170.17 rows=13967 width=16)
+               Sort Key: id
+               ->  Seq Scan on num_table  (cost=0.00..173.67 rows=13967 
width=16)
+ Optimizer: Postgres query optimizer
+(7 rows)
+
+select count(distinct a), sum(b) from num_table group by id;
+ count | sum 
+-------+-----
+     1 |   1
+     1 |   2
+     1 |   3
+(3 rows)
+
+-- count(distinct a) is a simple aggregation
+-- sum(c) is a split aggregation
+-- Before the fix, the final aggregation of sum(c) was mistakenly
+-- treated as simple aggregation, and led to the missing 
+-- deserialization step in the aggregation execution prep
+-- Numeric aggregation serializes partial aggregation states
+-- The executor then evaluated the aggregation state without deserializing it 
first
+-- This led to the creation of garbage NaN count, and caused NaN output
+explain select count(distinct a), sum(c) from num_table;
+                                                 QUERY PLAN                    
                              
+-------------------------------------------------------------------------------------------------------------
+ Finalize Aggregate  (cost=14212.20..14212.21 rows=1 width=40)
+   ->  Gather Motion 3:1  (slice1; segments: 3)  (cost=14210.17..14212.18 
rows=3 width=40)
+         ->  Partial Aggregate  (cost=14210.17..14210.18 rows=1 width=40)
+               ->  Redistribute Motion 3:3  (slice2; segments: 3)  
(cost=0.00..14140.33 rows=13967 width=40)
+                     Hash Key: a
+                     ->  Seq Scan on num_table  (cost=0.00..173.67 rows=13967 
width=40)
+ Optimizer: Postgres query optimizer
+(7 rows)
+
+select count(distinct a), sum(c) from num_table;
+ count | sum 
+-------+-----
+     3 |   6
+(1 row)
+
+explain select id, count(distinct a), avg(b), sum(c) from num_table group by 
grouping sets ((id,c));
+                                             QUERY PLAN                        
                      
+-----------------------------------------------------------------------------------------------------
+ Gather Motion 3:1  (slice1; segments: 3)  
(cost=10000001135.25..10000004159.03 rows=4190 width=108)
+   ->  GroupAggregate  (cost=10000001135.25..10000001365.70 rows=1397 
width=108)
+         Group Key: id, c
+         ->  Sort  (cost=1135.25..1170.17 rows=13967 width=48)
+               Sort Key: id, c
+               ->  Seq Scan on num_table  (cost=0.00..173.67 rows=13967 
width=48)
+ Optimizer: Postgres query optimizer
+(7 rows)
+
+select id, count(distinct a), avg(b), sum(c) from num_table group by grouping 
sets ((id,c));
+ id | count |          avg           | sum 
+----+-------+------------------------+-----
+  1 |     1 | 1.00000000000000000000 |   1
+  2 |     1 |     2.0000000000000000 |   2
+  3 |     1 |     3.0000000000000000 |   3
+(3 rows)
+
 reset optimizer_force_multistage_agg;
 reset optimizer_enable_use_distribution_in_dqa;
 drop table t_issue_659;
diff --git a/src/test/regress/expected/gp_dqa_optimizer.out 
b/src/test/regress/expected/gp_dqa_optimizer.out
index 787180b791..cb090292d3 100644
--- a/src/test/regress/expected/gp_dqa_optimizer.out
+++ b/src/test/regress/expected/gp_dqa_optimizer.out
@@ -2580,7 +2580,7 @@ analyze multiagg2;
 explain (verbose, costs off) select count(distinct b), sum(c) from multiagg1;
                                                QUERY PLAN                      
                         
 
--------------------------------------------------------------------------------------------------------
- Aggregate
+ Finalize Aggregate
    Output: count(b), sum(c)
    ->  Gather Motion 3:1  (slice1; segments: 3)
          Output: b, (PARTIAL sum(c))
@@ -2608,7 +2608,7 @@ select count(distinct b), sum(c) from multiagg1;
 explain (verbose, costs off) select count(distinct b), sum(c) from multiagg2;
                                                QUERY PLAN                      
                         
 
--------------------------------------------------------------------------------------------------------
- Aggregate
+ Finalize Aggregate
    Output: count(b), sum(c)
    ->  Gather Motion 3:1  (slice1; segments: 3)
          Output: b, (PARTIAL sum(c))
@@ -2635,6 +2635,126 @@ select count(distinct b), sum(c) from multiagg2;
 
 drop table multiagg1;
 drop table multiagg2;
+-- Support Multi-stage DQA with ride along aggregation in ORCA
+-- Historically, Agg aggsplit is identically equal to Aggref aggsplit
+-- In ORCA's attempt to support intermediate aggregation
+-- The two are allowed to differ
+-- Now Agg aggsplit is derived as bitwise OR of its children Aggref aggsplit
+-- The plan is to eventually make Agg aggsplit a dummy
+-- And use Aggref aggsplit to build trans/combine functions
+set optimizer_force_multistage_agg=on;
+create table num_table(id int, a bigint, b int, c numeric);
+NOTICE:  Table doesn't have 'DISTRIBUTED BY' clause -- Using column named 'id' 
as the Greenplum Database data distribution key for this table.
+HINT:  The 'DISTRIBUTED BY' clause determines the distribution of data. Make 
sure column(s) chosen are the optimal data distribution key to minimize skew.
+insert into num_table values(1,1,1,1),(2,2,2,2),(3,3,3,3);
+-- count(distinct a) is a simple aggregation
+-- sum(b) is a split aggregation
+-- Before the fix, in the final aggregation of sum(b)
+-- the executor mistakenly built a trans func instead of a combine func
+-- The trans func building process errored out due to mismatch between
+-- the input type (int) and trans type (bigint), and caused missing plan
+explain select count(distinct a), sum(b) from num_table;
+                                                 QUERY PLAN                    
                              
+-------------------------------------------------------------------------------------------------------------
+ Finalize Aggregate  (cost=0.00..431.00 rows=1 width=16)
+   ->  Gather Motion 3:1  (slice1; segments: 3)  (cost=0.00..431.00 rows=1 
width=16)
+         ->  Partial GroupAggregate  (cost=0.00..431.00 rows=1 width=16)
+               Group Key: a
+               ->  Sort  (cost=0.00..431.00 rows=1 width=16)
+                     Sort Key: a
+                     ->  Redistribute Motion 3:3  (slice2; segments: 3)  
(cost=0.00..431.00 rows=1 width=16)
+                           Hash Key: a
+                           ->  Partial GroupAggregate  (cost=0.00..431.00 
rows=1 width=16)
+                                 Group Key: a
+                                 ->  Sort  (cost=0.00..431.00 rows=1 width=12)
+                                       Sort Key: a
+                                       ->  Seq Scan on num_table  
(cost=0.00..431.00 rows=1 width=12)
+ Optimizer: Pivotal Optimizer (GPORCA)
+(14 rows)
+
+select count(distinct a), sum(b) from num_table;
+ count | sum 
+-------+-----
+     3 |   6
+(1 row)
+
+explain select count(distinct a), sum(b) from num_table group by id;
+                                     QUERY PLAN                                
     
+------------------------------------------------------------------------------------
+ Gather Motion 3:1  (slice1; segments: 3)  (cost=0.00..431.00 rows=1 width=16)
+   ->  Finalize GroupAggregate  (cost=0.00..431.00 rows=1 width=16)
+         Group Key: id
+         ->  Partial GroupAggregate  (cost=0.00..431.00 rows=1 width=20)
+               Group Key: id, a
+               ->  Sort  (cost=0.00..431.00 rows=1 width=16)
+                     Sort Key: id, a
+                     ->  Seq Scan on num_table  (cost=0.00..431.00 rows=1 
width=16)
+ Optimizer: Pivotal Optimizer (GPORCA)
+(9 rows)
+
+select count(distinct a), sum(b) from num_table group by id;
+ count | sum 
+-------+-----
+     1 |   1
+     1 |   2
+     1 |   3
+(3 rows)
+
+-- count(distinct a) is a simple aggregation
+-- sum(c) is a split aggregation
+-- Before the fix, the final aggregation of sum(c) was mistakenly
+-- treated as simple aggregation, and led to the missing 
+-- deserialization step in the aggregation execution prep
+-- Numeric aggregation serializes partial aggregation states
+-- The executor then evaluated the aggregation state without deserializing it 
first
+-- This led to the creation of garbage NaN count, and caused NaN output
+explain select count(distinct a), sum(c) from num_table;
+                                                 QUERY PLAN                    
                              
+-------------------------------------------------------------------------------------------------------------
+ Finalize Aggregate  (cost=0.00..431.00 rows=1 width=16)
+   ->  Gather Motion 3:1  (slice1; segments: 3)  (cost=0.00..431.00 rows=1 
width=16)
+         ->  Partial GroupAggregate  (cost=0.00..431.00 rows=1 width=16)
+               Group Key: a
+               ->  Sort  (cost=0.00..431.00 rows=1 width=16)
+                     Sort Key: a
+                     ->  Redistribute Motion 3:3  (slice2; segments: 3)  
(cost=0.00..431.00 rows=1 width=16)
+                           Hash Key: a
+                           ->  Partial GroupAggregate  (cost=0.00..431.00 
rows=1 width=16)
+                                 Group Key: a
+                                 ->  Sort  (cost=0.00..431.00 rows=1 width=16)
+                                       Sort Key: a
+                                       ->  Seq Scan on num_table  
(cost=0.00..431.00 rows=1 width=16)
+ Optimizer: Pivotal Optimizer (GPORCA)
+(14 rows)
+
+select count(distinct a), sum(c) from num_table;
+ count | sum 
+-------+-----
+     3 |   6
+(1 row)
+
+explain select id, count(distinct a), avg(b), sum(c) from num_table group by 
grouping sets ((id,c));
+                                     QUERY PLAN                                
     
+------------------------------------------------------------------------------------
+ Gather Motion 3:1  (slice1; segments: 3)  (cost=0.00..431.00 rows=1 width=28)
+   ->  Finalize GroupAggregate  (cost=0.00..431.00 rows=1 width=28)
+         Group Key: id, c
+         ->  Partial GroupAggregate  (cost=0.00..431.00 rows=1 width=36)
+               Group Key: id, c, a
+               ->  Sort  (cost=0.00..431.00 rows=1 width=24)
+                     Sort Key: id, c, a
+                     ->  Seq Scan on num_table  (cost=0.00..431.00 rows=1 
width=24)
+ Optimizer: Pivotal Optimizer (GPORCA)
+(9 rows)
+
+select id, count(distinct a), avg(b), sum(c) from num_table group by grouping 
sets ((id,c));
+ id | count |          avg           | sum 
+----+-------+------------------------+-----
+  1 |     1 | 1.00000000000000000000 |   1
+  2 |     1 |     2.0000000000000000 |   2
+  3 |     1 |     3.0000000000000000 |   3
+(3 rows)
+
 reset optimizer_force_multistage_agg;
 reset optimizer_enable_use_distribution_in_dqa;
 drop table t_issue_659;
diff --git a/src/test/regress/sql/gp_dqa.sql b/src/test/regress/sql/gp_dqa.sql
index 8b4cda3a8c..75a113c6f7 100644
--- a/src/test/regress/sql/gp_dqa.sql
+++ b/src/test/regress/sql/gp_dqa.sql
@@ -451,6 +451,44 @@ explain (verbose, costs off) select count(distinct b), 
sum(c) from multiagg2;
 select count(distinct b), sum(c) from multiagg2;
 drop table multiagg1;
 drop table multiagg2;
+
+-- Support Multi-stage DQA with ride along aggregation in ORCA
+-- Historically, Agg aggsplit is identically equal to Aggref aggsplit
+-- In ORCA's attempt to support intermediate aggregation
+-- The two are allowed to differ
+-- Now Agg aggsplit is derived as bitwise OR of its children Aggref aggsplit
+-- The plan is to eventually make Agg aggsplit a dummy
+-- And use Aggref aggsplit to build trans/combine functions
+set optimizer_force_multistage_agg=on;
+create table num_table(id int, a bigint, b int, c numeric);
+insert into num_table values(1,1,1,1),(2,2,2,2),(3,3,3,3);
+
+-- count(distinct a) is a simple aggregation
+-- sum(b) is a split aggregation
+-- Before the fix, in the final aggregation of sum(b)
+-- the executor mistakenly built a trans func instead of a combine func
+-- The trans func building process errored out due to mismatch between
+-- the input type (int) and trans type (bigint), and caused missing plan
+explain select count(distinct a), sum(b) from num_table;
+select count(distinct a), sum(b) from num_table;
+
+explain select count(distinct a), sum(b) from num_table group by id;
+select count(distinct a), sum(b) from num_table group by id;
+
+-- count(distinct a) is a simple aggregation
+-- sum(c) is a split aggregation
+-- Before the fix, the final aggregation of sum(c) was mistakenly
+-- treated as simple aggregation, and led to the missing 
+-- deserialization step in the aggregation execution prep
+-- Numeric aggregation serializes partial aggregation states
+-- The executor then evaluated the aggregation state without deserializing it 
first
+-- This led to the creation of garbage NaN count, and caused NaN output
+explain select count(distinct a), sum(c) from num_table;
+select count(distinct a), sum(c) from num_table;
+
+explain select id, count(distinct a), avg(b), sum(c) from num_table group by 
grouping sets ((id,c));
+select id, count(distinct a), avg(b), sum(c) from num_table group by grouping 
sets ((id,c));
+
 reset optimizer_force_multistage_agg;
 reset optimizer_enable_use_distribution_in_dqa;
 drop table t_issue_659;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to