Changeset: 96905a6ef18f for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/96905a6ef18f
Modified Files:
        sql/server/rel_optimizer.c
        sql/test/BugTracker-2017/Tests/side-effect.Bug-6397.test
        sql/test/SQLancer/Tests/sqlancer19.SQL.py
        sql/test/miscellaneous/Tests/simple_plans.test
Branch: default
Log Message:

Make sure rel_push_count_down returns the same output type as the original COUNT(*) expression (cast the pushed-down sql_mul result when the types differ)


diffs (119 lines):

diff --git a/sql/server/rel_optimizer.c b/sql/server/rel_optimizer.c
--- a/sql/server/rel_optimizer.c
+++ b/sql/server/rel_optimizer.c
@@ -1594,51 +1594,49 @@ rel_push_count_down(visitor *v, sql_rel 
                r && !r->exps && r->op == op_join && !(rel_is_ref(r)) &&
                /* currently only single count aggregation is handled, no other 
projects or aggregation */
                list_length(rel->exps) == 1 && 
exp_aggr_is_count(rel->exps->h->data)) {
-               sql_exp *nce, *oce;
-               sql_rel *gbl, *gbr;             /* Group By */
-               sql_rel *cp;                    /* Cross Product */
-               sql_subfunc *mult;
-               list *args, *types;
+               sql_exp *nce, *oce, *cnt1 = NULL, *cnt2 = NULL;
+               sql_rel *gbl = NULL, *gbr = NULL;       /* Group By */
+               sql_rel *cp = NULL;                                     /* 
Cross Product */
                sql_rel *srel;
 
                oce = rel->exps->h->data;
                if (oce->l) /* we only handle COUNT(*) */
                        return rel;
 
-               args = new_exp_list(v->sql->sa);
                srel = r->l;
                {
                        sql_subfunc *cf = sql_bind_func(v->sql, "sys", "count", 
sql_bind_localtype("void"), NULL, F_AGGR);
-                       sql_exp *cnt, *e = exp_aggr(v->sql->sa, NULL, cf, 
need_distinct(oce), need_no_nil(oce), oce->card, 0);
+                       sql_exp *e = exp_aggr(v->sql->sa, NULL, cf, 
need_distinct(oce), need_no_nil(oce), oce->card, 0);
 
                        exp_label(v->sql->sa, e, ++v->sql->label);
-                       cnt = exp_ref(v->sql, e);
+                       cnt1 = exp_ref(v->sql, e);
                        gbl = rel_groupby(v->sql, rel_dup(srel), NULL);
                        set_processed(gbl);
                        rel_groupby_add_aggr(v->sql, gbl, e);
-                       append(args, cnt);
                }
 
                srel = r->r;
                {
                        sql_subfunc *cf = sql_bind_func(v->sql, "sys", "count", 
sql_bind_localtype("void"), NULL, F_AGGR);
-                       sql_exp *cnt, *e = exp_aggr(v->sql->sa, NULL, cf, 
need_distinct(oce), need_no_nil(oce), oce->card, 0);
+                       sql_exp *e = exp_aggr(v->sql->sa, NULL, cf, 
need_distinct(oce), need_no_nil(oce), oce->card, 0);
 
                        exp_label(v->sql->sa, e, ++v->sql->label);
-                       cnt = exp_ref(v->sql, e);
+                       cnt2 = exp_ref(v->sql, e);
                        gbr = rel_groupby(v->sql, rel_dup(srel), NULL);
                        set_processed(gbr);
                        rel_groupby_add_aggr(v->sql, gbr, e);
-                       append(args, cnt);
                }
 
                cp = rel_crossproduct(v->sql->sa, gbl, gbr, op_join);
 
-               types = sa_list(v->sql->sa);
-               for(node *n = args->h; n; n = n->next)
-                       list_append(types, exp_subtype(n->data));
-               mult = sql_bind_func_(v->sql, "sys", "sql_mul", types, F_FUNC);
-               nce = exp_op(v->sql->sa, args, mult);
+               if (!(nce = rel_binop_(v->sql, NULL, cnt1, cnt2, "sys", 
"sql_mul", card_value))) {
+                       v->sql->session->status = 0;
+                       v->sql->errstr[0] = '\0';
+                       return rel; /* error, fallback to original expression */
+               }
+               /* because of remote plans, make sure "sql_mul" returns bigint. 
The cardinality is atomic, so no major performance penalty */
+               if (subtype_cmp(exp_subtype(oce), exp_subtype(nce)) != 0)
+                       nce = exp_convert(v->sql->sa, nce, exp_subtype(nce), 
exp_subtype(oce));
                if (exp_name(oce))
                        exp_prop_alias(v->sql->sa, nce, oce);
 
diff --git a/sql/test/BugTracker-2017/Tests/side-effect.Bug-6397.test 
b/sql/test/BugTracker-2017/Tests/side-effect.Bug-6397.test
--- a/sql/test/BugTracker-2017/Tests/side-effect.Bug-6397.test
+++ b/sql/test/BugTracker-2017/Tests/side-effect.Bug-6397.test
@@ -32,6 +32,8 @@ bat.pack
 5
 bat.single
 2
+batcalc.lng
+1
 querylog.define
 1
 sql.my_generate_series
diff --git a/sql/test/SQLancer/Tests/sqlancer19.SQL.py 
b/sql/test/SQLancer/Tests/sqlancer19.SQL.py
--- a/sql/test/SQLancer/Tests/sqlancer19.SQL.py
+++ b/sql/test/SQLancer/Tests/sqlancer19.SQL.py
@@ -257,11 +257,21 @@ with SQLTestCase() as cli:
     cli.execute("SELECT 1 FROM v7 CROSS JOIN ((SELECT 1) UNION ALL (SELECT 2)) 
AS sub0(c0);") \
         
.assertSucceeded().assertDataResultMatch([(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,)])
     cli.execute("""
+        SELECT 1 FROM (VALUES (2),(3)) x(x) FULL OUTER JOIN (SELECT t1.c1 <= 
CAST(t1.c1 AS INT) FROM t1) AS sub0(c0) ON true WHERE sub0.c0
+        UNION ALL
+        SELECT 1 FROM (VALUES (2),(3)) x(x) FULL OUTER JOIN (SELECT t1.c1 <= 
CAST(t1.c1 AS INT) FROM t1) AS sub0(c0) ON true;
+        """).assertSucceeded() \
+        
.assertDataResultMatch([(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,)])
+    cli.execute("""
         SELECT 1 FROM (VALUES (2),(3)) x(x) FULL OUTER JOIN (SELECT rt1.c1 <= 
CAST(rt1.c1 AS INT) FROM rt1) AS sub0(c0) ON true WHERE sub0.c0
         UNION ALL
         SELECT 1 FROM (VALUES (2),(3)) x(x) FULL OUTER JOIN (SELECT rt1.c1 <= 
CAST(rt1.c1 AS INT) FROM rt1) AS sub0(c0) ON true;
         """).assertSucceeded() \
         
.assertDataResultMatch([(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,),(1,)])
+    cli.execute("SELECT count(0.3121149) FROM (select case when 2 > 1 then 0.3 
end from (select 1 from t3) x(x)) v100(vc1), t3 WHERE 5 >= sinh(CAST(v100.vc1 
AS REAL));") \
+        .assertSucceeded().assertDataResultMatch([(36,)])
+    cli.execute("SELECT count(0.3121149) FROM (select case when 2 > 1 then 0.3 
end from (select 1 from rt3) x(x)) v100(vc1), rt3 WHERE 5 >= sinh(CAST(v100.vc1 
AS REAL));") \
+        .assertSucceeded().assertDataResultMatch([(36,)])
     cli.execute("ROLLBACK;")
 
     cli.execute("CREATE FUNCTION mybooludf(a bool) RETURNS BOOL RETURN a;")
diff --git a/sql/test/miscellaneous/Tests/simple_plans.test 
b/sql/test/miscellaneous/Tests/simple_plans.test
--- a/sql/test/miscellaneous/Tests/simple_plans.test
+++ b/sql/test/miscellaneous/Tests/simple_plans.test
@@ -477,7 +477,7 @@ project (
 | | |  [ "sys"."cnt"(clob "sys", clob "another_t") NOT NULL as "%2"."%2" ],
 | | |  [ "sys"."cnt"(clob "sys", clob "another_t") NOT NULL as "%3"."%3" ]
 | | ) [  ]
-| ) [ "sys"."sql_mul"("%2"."%2" NOT NULL, "%3"."%3" NOT NULL) NOT NULL as 
"%1"."%1" ]
+| ) [ bigint["sys"."sql_mul"("%2"."%2" NOT NULL, "%3"."%3" NOT NULL) NOT NULL] 
NOT NULL as "%1"."%1" ]
 ) [ "%1"."%1" NOT NULL ]
 
 statement ok
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to