Changeset: 9fe5afe1f73e for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/9fe5afe1f73e
Modified Files:
        sql/backends/monet5/rel_bin.c
        sql/test/cte/Tests/test_recursive_cte_union.test
Branch: recursive_cte
Log Message:

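Push the TOPN (LIMIT/OFFSET) of a query down into the recursive UNION of a recursive CTE, so the recursion loop can stop once enough rows have been produced; re-enable the recursive CTE LIMIT tests that were marked knownfail.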


diffs (290 lines):

diff --git a/sql/backends/monet5/rel_bin.c b/sql/backends/monet5/rel_bin.c
--- a/sql/backends/monet5/rel_bin.c
+++ b/sql/backends/monet5/rel_bin.c
@@ -4000,20 +4000,71 @@ subres_assign_resultvars(backend *be, st
        return stmt_list(be, nstmt);
 }
 
+static sql_exp*
+topn_limit(sql_rel *rel)
+{
+       if (rel->exps) {
+               sql_exp *limit = rel->exps->h->data;
+               if (exp_is_null(limit)) /* If the limit is NULL, ignore the 
value */
+                       return NULL;
+               return limit;
+       }
+       return NULL;
+}
+
+static sql_exp*
+topn_offset(sql_rel *rel)
+{
+       if (rel->exps && list_length(rel->exps) > 1) {
+               sql_exp *offset = rel->exps->h->next->data;
+
+               return offset;
+       }
+       return NULL;
+}
+
 static stmt *
-rel2bin_recursive_munion(backend *be, sql_rel *rel, list *refs)
+stmt_limit_value(backend *be, sql_rel *topn)
+{
+       stmt *l = NULL;
+
+       if (topn) {
+               sql_exp *le = topn_limit(topn);
+               sql_exp *oe = topn_offset(topn);
+
+               if (!le) { /* Don't push only offset */
+                       topn = NULL;
+               } else {
+                       l = exp_bin(be, le, NULL, NULL, NULL, NULL, NULL, NULL, 
0, 0, 0);
+                       if(!l)
+                               return NULL;
+                       if (oe) {
+                               sql_subtype *lng = sql_bind_localtype("lng");
+                               sql_subfunc *add = 
sql_bind_func_result(be->mvc, "sys", "sql_add", F_FUNC, true, lng, 2, lng, lng);
+                               stmt *o = exp_bin(be, oe, NULL, NULL, NULL, 
NULL, NULL, NULL, 0, 0, 0);
+                               if(!o)
+                                       return NULL;
+                               l = stmt_binop(be, l, o, NULL, add);
+                       }
+               }
+       }
+       return l;
+}
+
+static stmt *
+rel2bin_recursive_munion(backend *be, sql_rel *rel, list *refs, sql_rel *topn)
 {
        mvc *sql = be->mvc;
        stmt *rel_stmt = NULL, *sub;
        int nr_unions = list_length((list*)rel->l);
        if (nr_unions != 2)
                return sql_error(sql, 10, SQLSTATE(27000) "UNION: recursive 
unions need a base and recusive part");
+       stmt *l = stmt_limit_value(be, topn);
 
        bool distinct = need_distinct(rel);
        sql_rel *base = ((list*)rel->l)->h->data;
        sql_rel *recursive = ((list*)rel->l)->h->next->data;
 
-       /* TODO handle distinct */
        /* base part */
        rel_stmt = subrel_bin(be, base, refs);
        rel_stmt = subrel_project(be, rel_stmt, refs, base);
@@ -4021,6 +4072,7 @@ rel2bin_recursive_munion(backend *be, sq
                return NULL;
 
        if (recursive) {
+               int gcnt = 0;
                list *result_table = sa_list(be->mvc->sa);
                for(node *n = rel->exps->h; n; n = n->next) {
                        sql_exp *e = n->data;
@@ -4036,7 +4088,16 @@ rel2bin_recursive_munion(backend *be, sq
                sql_subfunc *cnt = sql_bind_func(sql, "sys", "count", 
sql_bind_localtype("void"), NULL, F_AGGR, true, true);
                stmt *cnts = stmt_aggr(be, rel_stmt->op4.lval->h->data, NULL, 
NULL, cnt, 1, 0, 1);
 
-               /* while cnt > 0: */
+               /* if topn keep total count */
+               if (l) {
+                       InstrPtr r = newAssignment(be->mb);
+                       gcnt = r->argv[0];
+                       r->argc = r->retc = 1;
+                       r = pushArgument(be->mb, r, cnts->nr);
+                       pushInstruction(be->mb, r);
+               }
+
+               /* while cnt > 0 (and total < limit): */
         InstrPtr r = newAssignment(be->mb);
         if (r == NULL)
                        return NULL;
@@ -4046,13 +4107,24 @@ rel2bin_recursive_munion(backend *be, sq
         r = pushBit(be->mb, r, TRUE);
                pushInstruction(be->mb, r);
 
-               r = newStmtArgs(be->mb, calcRef, "<=", 3);
+               if (l)
+                       r = newStmtArgs(be->mb, calcRef, "between", 9);
+               else
+                       r = newStmtArgs(be->mb, calcRef, "<=", 3);
                if (r == NULL)
                        return NULL;
         getArg(r, 0) = barrier_var;
                r->barrier = LEAVEsymbol;
                r = pushArgument(be->mb, r, cnts->nr);
                r = pushLng(be->mb, r, 0);
+               if (l) {
+                       r = pushArgument(be->mb, r, l->nr);
+                       r = pushBit(be->mb, r, FALSE); /* not symetrical */
+                       r = pushBit(be->mb, r, TRUE);  /* including lower bound 
*/
+                       r = pushBit(be->mb, r, FALSE); /* excluding upper bound 
*/
+                       r = pushBit(be->mb, r, FALSE); /* nils_false */
+                       r = pushBit(be->mb, r, TRUE);  /* anti */
+               }
                pushInstruction(be->mb, r);
 
                /* insert temptable into result_table and make link between 
result_table and table name */
@@ -4064,6 +4136,24 @@ rel2bin_recursive_munion(backend *be, sq
                        n->data = stmt_alias(be, r, a->label, a->tname, 
a->cname);
                }
 
+               if (l) {
+                       InstrPtr r = newStmtArgs(be->mb, calcRef, "+", 3);
+                       r->argv[0] = gcnt;
+                       r->argc = r->retc = 1;
+                       r = pushArgument(be->mb, r, cnts->nr);
+                       r = pushArgument(be->mb, r, gcnt);
+                       pushInstruction(be->mb, r);
+
+                       r = newStmtArgs(be->mb, calcRef, ">", 3);
+                       if (r == NULL)
+                               return NULL;
+                       getArg(r, 0) = barrier_var;
+                       r->barrier = LEAVEsymbol;
+                       r = pushArgument(be->mb, r, gcnt);
+                       r = pushArgument(be->mb, r, l->nr);
+                       pushInstruction(be->mb, r);
+               }
+
                /* recursive part */
                stmt *rec = subrel_bin(be, recursive, refs);
                if (!rec)
@@ -4129,7 +4219,7 @@ static stmt *
 rel2bin_munion(backend *be, sql_rel *rel, list *refs)
 {
        if (is_recursive(rel))
-               return rel2bin_recursive_munion(be, rel, refs);
+               return rel2bin_recursive_munion(be, rel, refs, NULL);
 
        mvc *sql = be->mvc;
        list *l, *rstmts;
@@ -4481,29 +4571,6 @@ sql_reorder(backend *be, stmt *order, li
        return stmt_list(be, l);
 }
 
-static sql_exp*
-topn_limit(sql_rel *rel)
-{
-       if (rel->exps) {
-               sql_exp *limit = rel->exps->h->data;
-               if (exp_is_null(limit)) /* If the limit is NULL, ignore the 
value */
-                       return NULL;
-               return limit;
-       }
-       return NULL;
-}
-
-static sql_exp*
-topn_offset(sql_rel *rel)
-{
-       if (rel->exps && list_length(rel->exps) > 1) {
-               sql_exp *offset = rel->exps->h->next->data;
-
-               return offset;
-       }
-       return NULL;
-}
-
 static stmt *
 rel2bin_project(backend *be, sql_rel *rel, list *refs, sql_rel *topn)
 {
@@ -4513,30 +4580,11 @@ rel2bin_project(backend *be, sql_rel *re
        stmt *sub = NULL, *psub = NULL;
        stmt *l = NULL;
 
-       if (topn) {
-               sql_exp *le = topn_limit(topn);
-               sql_exp *oe = topn_offset(topn);
-
-               if (!le) { /* Don't push only offset */
-                       topn = NULL;
-               } else {
-                       l = exp_bin(be, le, NULL, NULL, NULL, NULL, NULL, NULL, 
0, 0, 0);
-                       if(!l)
-                               return NULL;
-                       if (oe) {
-                               sql_subtype *lng = sql_bind_localtype("lng");
-                               sql_subfunc *add = sql_bind_func_result(sql, 
"sys", "sql_add", F_FUNC, true, lng, 2, lng, lng);
-                               stmt *o = exp_bin(be, oe, NULL, NULL, NULL, 
NULL, NULL, NULL, 0, 0, 0);
-                               if(!o)
-                                       return NULL;
-                               l = stmt_binop(be, l, o, NULL, add);
-                       }
-               }
-       }
-
        if (!rel->exps)
                return stmt_none(be);
 
+       l = stmt_limit_value(be, topn);
+
        if (rel->l) { /* first construct the sub relation */
                sql_rel *l = rel->l;
                if (l->op == op_ddl) {
@@ -4903,7 +4951,14 @@ rel2bin_topn(backend *be, sql_rel *rel, 
        if (rel->l) { /* first construct the sub relation */
                sql_rel *rl = rel->l;
 
-               if (rl->op == op_project) {
+               if (rl->op == op_munion && is_recursive(rl)) {
+                       if (rel_is_ref(rl)) {
+                               sub = refs_find_rel(refs, rl);
+                               if (!sub)
+                                       sub = rel2bin_recursive_munion(be, rl, 
refs, rel);
+                       } else
+                               sub = rel2bin_recursive_munion(be, rl, refs, 
rel);
+               } else if (rl->op == op_project) {
                        if (rel_is_ref(rl)) {
                                sub = refs_find_rel(refs, rl);
                                if (!sub)
@@ -4924,8 +4979,6 @@ rel2bin_topn(backend *be, sql_rel *rel, 
        n = sub->op4.lval->h;
        if (n) {
                stmt *limit = NULL, *sc = n->data;
-               //const char *cname = column_name(sql->sa, sc);
-               //const char *tname = table_name(sql->sa, sc);
                list *newl = sa_list(sql->sa);
                int oldvtop = be->mb->vtop, oldstop = be->mb->stop;
 
diff --git a/sql/test/cte/Tests/test_recursive_cte_union.test 
b/sql/test/cte/Tests/test_recursive_cte_union.test
--- a/sql/test/cte/Tests/test_recursive_cte_union.test
+++ b/sql/test/cte/Tests/test_recursive_cte_union.test
@@ -76,7 +76,6 @@ 2
 3
 
 # hash join
-skipif knownfail
 query I
 with recursive t as (select 1 as x union all select * from (select x from t 
where x < 5) tbl(i) join (select 1) tbl2(i) using (i)) select * from t limit 3;
 ----
@@ -84,7 +83,6 @@ 1
 1
 1
 
-skipif knownfail
 query I
 with recursive t as (select 1 as x union all select * from (select 1) tbl2(i) 
join (select x from t where x < 5) tbl(i) using (i)) select * from t limit 3;
 ----
@@ -92,17 +90,15 @@ 1
 1
 1
 
-skipif knownfail
 query I
-with recursive t as (select 1 as x union all select * from (select x from t 
where x < 5) tbl(i) join (select first(i) from (values (1)) tbl3(i) limit 1) 
tbl2(i) using (i)) select * from t limit 3;
+with recursive t as (select 1 as x union all select * from (select x from t 
where x < 5) tbl(i) join (select any_value(i) from (values (1)) tbl3(i) limit 
1) tbl2(i) using (i)) select * from t limit 3;
 ----
 1
 1
 1
 
-skipif knownfail
 query I
-with recursive t as (select 1 as x union all select * from (select first(i) 
from (values (1)) tbl3(i) limit 1) tbl2(i) join (select x from t where x < 5) 
tbl(i) using (i)) select * from t limit 3;
+with recursive t as (select 1 as x union all select * from (select 
any_value(i) from (values (1)) tbl3(i) limit 1) tbl2(i) join (select x from t 
where x < 5) tbl(i) using (i)) select * from t limit 3;
 ----
 1
 1
_______________________________________________
checkin-list mailing list -- [email protected]
To unsubscribe send an email to [email protected]

Reply via email to