Changeset: ac55e0235a4a for MonetDB
URL: http://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=ac55e0235a4a
Modified Files:
        sql/backends/monet5/rel_bin.c
        sql/backends/monet5/sql_gencode.c
        sql/backends/monet5/sql_statement.c
        sql/backends/monet5/sql_statement.h
        sql/server/rel_graph.c
        sql/server/sql_parser.y
Branch: graph1
Log Message:

Codegen: handle the case where a constant (atom) is the argument of CHEAPEST SUM ( ); SPFW then runs a BFS and the result is multiplied by the constant

TODO: this does not work for decimal constants yet; scaling the BFS result to the
decimal type overflows. Non-numeric constants trigger an assertion; they should
be rejected during semantic analysis instead.


diffs (truncated from 334 to 300 lines):

diff --git a/sql/backends/monet5/rel_bin.c b/sql/backends/monet5/rel_bin.c
--- a/sql/backends/monet5/rel_bin.c
+++ b/sql/backends/monet5/rel_bin.c
@@ -23,6 +23,11 @@
 
 static stmt * subrel_bin(mvc *sql, sql_rel *rel, list *refs);
 
+// graph related functions
+static bool graph_is_spfw_output(const stmt *s);
+static bool graph_generate_shortest_path(stmt *s);
+static stmt *graph_exp_cost_atom(mvc *sql, stmt *spfw, stmt* const_weight);
+
 static stmt *
 refs_find_rel(list *refs, sql_rel *rel)
 {
@@ -615,8 +620,9 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
                        stmt *eperm = NULL; // the permutation for the sorted 
order of qfrom
                        sql_subaggr *aggr_count = sql_bind_aggr(sql->sa, 
sql->session->schema, "count", NULL);
                        stmt *domain =NULL, *query =NULL;
-                       stmt *spfw =NULL;
+                       stmt *spfw =NULL, *spfw_output =NULL;
                        stmt *weights = NULL;
+                       bool compute_shortest_path = false;
                        int spfw_flags = 0;
 
                        // generate the depending expressions
@@ -633,6 +639,7 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
                        eto = exp_bin(sql, g->dst->h->data, graph, NULL, NULL, 
NULL, NULL, NULL, refs);
                        if(!eto) { assert(0); return NULL; }
                        if(g->cost){
+                               compute_shortest_path = true; // yes, we want it
                                weights = exp_bin(sql, g->cost, graph, NULL, 
NULL, NULL, NULL, NULL, refs);
                                if(!weights) { assert(0); return NULL; }
                        }
@@ -656,7 +663,7 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
                        l = sa_list(sql->sa);
                        list_append(l, efrom);
                        list_append(l, eto);
-                       if(weights) { list_append(l, weights); }
+                       if(weights && weights->nrcols > 0) { list_append(l, 
weights); }
                        graph = stmt_list(sql->sa, l);
 
                        // generate the query
@@ -672,16 +679,28 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
 
                        // create the operator
                        if(right != NULL){ spfw_flags |= SPFW_CROSS_PRODUCT; }
-                       if(weights != NULL){ spfw_flags |= SPFW_SHORTEST_PATH; }
+                       if(compute_shortest_path){ spfw_flags |= 
SPFW_SHORTEST_PATH; }
                        spfw = stmt_spfw(sql->sa, query, graph, spfw_flags);
-                       if(weights != NULL){ // propagate the temp~ name for 
the column
+                       if(compute_shortest_path){ // propagate the temp~ name 
for the column
                                spfw->tname = g->cost->rname;
                                spfw->cname = g->cost->name;
                        }
 
-                       print_tree(sql->sa, spfw);
-
-                       return spfw;
+                       // if the user input is a constant, then let spfw 
perform a BFS and multiply
+                       // the constant at the end with the result
+                       if(weights && weights->nrcols == 0){ // this is 
actually a constant
+                               stmt* cost = graph_exp_cost_atom(sql, spfw, 
weights);
+                               if(!cost){
+                                       assert(0 && "Unable to generate the 
shortest path out of a constant");
+                                       return NULL;
+                               }
+                               // final output
+                               spfw_output = stmt_spfw_output(sql->sa, spfw, 
cost);
+                       } else {
+                               spfw_output = spfw; // done
+                       }
+
+                       return spfw_output;
                }
                if (e->flag == cmp_in || e->flag == cmp_notin) {
                        return handle_in_exps(sql, e->l, e->r, left, right, 
grp, ext, cnt, sel, (e->flag == cmp_in), 0);
@@ -846,6 +865,88 @@ exp_bin(mvc *sql, sql_exp *e, stmt *left
        return s;
 }
 
+// True iff the given statement is spfw or the output of the spfw
+static bool graph_is_spfw_output(const stmt *s){
+       assert(s != NULL && "is_spfw_output(): null argument");
+       return (s->type == st_spfw) || (s->type == st_spfw_output);
+}
+
+
+// True iff the statement is an spfw output AND it outputs a shortest path 
column
+static bool graph_generate_shortest_path(stmt *s){
+       return graph_is_spfw_output(s) && (s->flag & SPFW_SHORTEST_PATH);
+}
+
+// When the user provides an atom as argument of CHEAPEST SUM ( .. ), the 
weight
+// is not forwarded to the SPFW operator. Instead the SPFW computes a BFS and
+// the input provided by the user is multiplied afterward with the
+// result of the BFS.
+static stmt *graph_exp_cost_atom(mvc *sql, stmt *spfw, stmt* const_weight){
+       sql_subfunc* fmult = NULL;
+       sql_subtype shortest_path_default_type;
+       stmt* bfs_output = NULL;
+       stmt* mult = NULL;
+       list* l = NULL;
+
+       assert(const_weight && "Null argument");
+       assert(const_weight->nrcols == 0 && "Expected an atom as input");
+
+       // get a reference to the cost computed by the spfw
+       bfs_output = stmt_result(sql->sa, spfw, 2);
+
+       // get a reference to the type lng
+       if(!sql_find_subtype(&shortest_path_default_type, "bigint", 64, 0)){
+               assert(0 && "Unable to bind type lng");
+               return NULL;
+       }
+
+       // decide the type of the final expression
+       switch(const_weight->op4.typeval.type->eclass){
+       case EC_NUM: { // integers
+               // if it is an integer, convert to lng, as this is the result 
of the BFS
+               const_weight = stmt_convert(sql->sa, const_weight, 
&(const_weight->op4.typeval), &shortest_path_default_type);
+               if(!const_weight){
+                       assert(0 && "Unable to convert the type for the 
constant to lng");
+                       return NULL;
+               }
+       } break;
+       case EC_DEC: { // float or double
+               // FIXME this always ends up in overflow, multiplying the 
result of the BFS by 10 ^ (number of digits in the atom)
+               // convert the result of the bfs to decimal
+               bfs_output = stmt_convert(sql->sa, bfs_output, 
&shortest_path_default_type, &(const_weight->op4.typeval));
+               if(!bfs_output){
+                       assert(0 && "Unable to convert the result of spfw/bfs 
to decimal type");
+                       return NULL;
+               }
+       } break;
+       default:
+               assert(0 && "Type not handled");
+               return NULL;
+       }
+
+       // transform the atom into a column
+       const_weight = const_column(sql->sa, const_weight);
+       if(!const_weight){
+               assert(0 && "Unable to materialize the column weights with a 
constant");
+               return NULL;
+       }
+
+       // get a reference to the function multiply
+       fmult = sql_bind_func(sql->sa, sql->session->schema, "sql_mul", 
&(const_weight->op4.typeval), &(const_weight->op4.typeval), F_FUNC);
+       if(!fmult){
+               assert(0 && "Unable to bind the function sql_mult (*)");
+               return NULL;
+       }
+
+       // perform the multiplication
+       l = sa_list(sql->sa);
+       list_append(l, const_weight);
+       list_append(l, bfs_output);
+       mult = stmt_Nop(sql->sa, stmt_list(sql->sa, l), fmult);
+
+       return mult;
+}
+
 static stmt *check_types(mvc *sql, sql_subtype *ct, stmt *s, check_type tpe);
 
 static stmt *
@@ -1809,7 +1910,7 @@ rel2bin_join( mvc *sql, sql_rel *rel, li
                        if (s->type != st_join && 
                            s->type != st_join2 &&
                            s->type != st_joinN &&
-                           s->type != st_spfw) {
+                           !graph_is_spfw_output(s)) { // spfw can handle 
joins directly
                                /* predicate */
                                if (!list_length(lje) && s->nrcols == 0) { 
                                        stmt *l = bin_first_column(sql->sa, 
left);
@@ -1832,7 +1933,7 @@ rel2bin_join( mvc *sql, sql_rel *rel, li
                        list_append(lje, s->op1);
                        list_append(rje, s->op2);
 
-                       if(s->type == st_spfw && (s->flag & 
SPFW_SHORTEST_PATH)){
+                       if(graph_generate_shortest_path(s)){
                                list_append(shoooortestpaths, s);
                        }
                }
@@ -1887,7 +1988,7 @@ rel2bin_join( mvc *sql, sql_rel *rel, li
                                return NULL;
                        }
 
-                       if ( s->type == st_spfw && (s->flag & 
SPFW_SHORTEST_PATH) ){
+                       if (graph_generate_shortest_path(s)){
                                list_append(shoooortestpaths, s);
                        }
 
@@ -2661,7 +2762,7 @@ rel2bin_select( mvc *sql, sql_rel *rel, 
                        sel = s;
                }
 
-               if(s->type == st_spfw && (s->flag & SPFW_SHORTEST_PATH)){
+               if(graph_generate_shortest_path(s)){
                        list_append(shooortestpaths, s);
                }
        }
diff --git a/sql/backends/monet5/sql_gencode.c 
b/sql/backends/monet5/sql_gencode.c
--- a/sql/backends/monet5/sql_gencode.c
+++ b/sql/backends/monet5/sql_gencode.c
@@ -2939,6 +2939,30 @@ static int
                        renameVariable(mb, getArg(q, 1), "r1_%d", s->nr); // 
filter dst
                        renameVariable(mb, getArg(q, 2), "r2_%d", s->nr); // 
shortest path (if required)
                } break;
+               case st_spfw_output: {
+                       InstrPtr p = NULL;
+
+                       // generate spfw
+                       if(_dumpstmt(sql, mb, s->op1) < 0)
+                               return -1;
+
+                       // generate the shortest path
+                       if(_dumpstmt(sql, mb, s->op2) < 0)
+                               return -1;
+
+                       p = newAssignment(mb);
+                       p = pushArgument(mb, p, _dumpstmt(sql, mb, 
stmt_result(sql->mvc->sa, s->op1, 0)));
+                       s->nr = getDestVar(p); // forward the result of the join
+                       q = p; // result
+
+                       p = newAssignment(mb);
+                       p = pushArgument(mb, p, _dumpstmt(sql, mb, 
stmt_result(sql->mvc->sa, s->op1, 1)));
+                       renameVariable(mb, getDestVar(p), "r1_%d", s->nr); // 
as above
+
+                       p = newAssignment(mb);
+                       p = pushArgument(mb, p, s->op2->nr);
+                       renameVariable(mb, getDestVar(p), "r2_%d", s->nr); // 
shortest path
+               } break;
                case st_void2oid: {
                        int ref_op = _dumpstmt(sql, mb, s->op1);
                        if(ref_op < 0)
diff --git a/sql/backends/monet5/sql_statement.c 
b/sql/backends/monet5/sql_statement.c
--- a/sql/backends/monet5/sql_statement.c
+++ b/sql/backends/monet5/sql_statement.c
@@ -127,6 +127,7 @@ st_type2string(st_type type)
                ST(prefixsum);
                ST(slices);
                ST(spfw);
+               ST(spfw_output);
                ST(void2oid);
        default:
                return "unknown";       /* just needed for broken compilers ! */
@@ -335,6 +336,7 @@ stmt_deps(list *dep_list, stmt *s, int d
                        case st_prefixsum:
                        case st_slices:
                        case st_spfw:
+                       case st_spfw_output:
                        case st_void2oid:
                                if (s->op1)
                                        push(s->op1);
@@ -1682,6 +1684,20 @@ stmt_spfw(sql_allocator *sa, stmt* query
 }
 
 stmt *
+stmt_spfw_output(sql_allocator *sa, stmt* spfw, stmt* shortest_path)
+{
+       stmt * s = stmt_create(sa, st_spfw_output);
+       s->op1 = spfw;
+       s->op2 = shortest_path;
+       s->nrcols = spfw->nrcols;
+       s->flag = spfw->flag;
+       s->tname = spfw->tname;
+       s->cname = spfw->cname;
+
+       return s;
+}
+
+stmt *
 stmt_void2oid(sql_allocator *sa, stmt *op){
        stmt *s = stmt_create(sa, st_void2oid);
        s->op1 = op;
diff --git a/sql/backends/monet5/sql_statement.h 
b/sql/backends/monet5/sql_statement.h
--- a/sql/backends/monet5/sql_statement.h
+++ b/sql/backends/monet5/sql_statement.h
@@ -103,6 +103,7 @@ typedef enum stmt_type {
        st_prefixsum,
        st_slices,
        st_spfw,
+       st_spfw_output,
        st_void2oid,
 } st_type;
 
@@ -111,6 +112,7 @@ typedef enum stmt_type {
 #define ANTI ANTISEL
 #define GRP_DONE 32
 
+/* flags for st_spfw */
 #define SPFW_CROSS_PRODUCT 0x1 /* Perform a cross product instead of a filter 
op~ */
 #define SPFW_SHORTEST_PATH 0x2 /* Retrieve also the cost of the path */
 
@@ -258,9 +260,9 @@ extern stmt *stmt_mkpartition(sql_alloca
 extern stmt *stmt_prefixsum(sql_allocator *sa, stmt *op, stmt 
*domain_cardinality);
 extern stmt *stmt_slices(sql_allocator *sa, stmt *op, int num);
 extern stmt *stmt_spfw(sql_allocator *sa, stmt *query, stmt *graph, int flags);
+extern stmt *stmt_spfw_output(sql_allocator *sa, stmt* spfw, stmt* 
shortest_path);
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to