Changeset: 8eb4729bbdb8 for MonetDB
URL: http://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=8eb4729bbdb8
Modified Files:
        monetdb5/optimizer/opt_candidates.c
        monetdb5/optimizer/opt_prelude.c
        monetdb5/optimizer/opt_prelude.h
        monetdb5/optimizer/opt_support.c
        sql/backends/monet5/rel_bin.c
        sql/backends/monet5/sql_statement.c
        sql/backends/monet5/sql_statement.h
        sql/server/rel_select.c
        sql/server/sql_parser.h
        sql/server/sql_parser.y
        sql/server/sql_scan.c
Branch: stratified_sampling
Log Message:

Add weighted sampling to the SQL layer.

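The syntax for weighted sampling is:

SELECT * FROM <table> SAMPLE <sample_size> WITH WEIGHTS <weight_column>;

where <sample_size> is a positive integer literal giving the number of rows
to draw and <weight_column> is the column (or expression) that supplies the
weight of each row.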


diffs (232 lines):

diff --git a/monetdb5/optimizer/opt_candidates.c 
b/monetdb5/optimizer/opt_candidates.c
--- a/monetdb5/optimizer/opt_candidates.c
+++ b/monetdb5/optimizer/opt_candidates.c
@@ -64,7 +64,7 @@ OPTcandidatesImplementation(Client cntxt
                                setVarCList(mb,getArg(p,0));
                }
                else if (getModuleId(p) == sampleRef) {
-                       if (getFunctionId(p) == subuniformRef)
+                       if (getFunctionId(p) == subuniformRef || 
getFunctionId(p) == subweightedRef)
                                setVarCList(mb, getArg(p, 0));
                }
                else if (getModuleId(p) == groupRef && p->retc > 1) {
diff --git a/monetdb5/optimizer/opt_prelude.c b/monetdb5/optimizer/opt_prelude.c
--- a/monetdb5/optimizer/opt_prelude.c
+++ b/monetdb5/optimizer/opt_prelude.c
@@ -286,6 +286,7 @@ str not_uniqueRef;
 str sampleRef;
 str uniqueRef;
 str subuniformRef;
+str subweightedRef;
 str unlockRef;
 str unpackRef;
 str unpinRef;
@@ -567,6 +568,7 @@ void optimizerInit(void)
        not_uniqueRef= putName("not_unique");
        sampleRef= putName("sample");
        uniqueRef= putName("unique");
+       subweightedRef = putName("subweighted");
        subuniformRef= putName("subuniform");
        unlockRef= putName("unlock");
        unpackRef = putName("unpack");
diff --git a/monetdb5/optimizer/opt_prelude.h b/monetdb5/optimizer/opt_prelude.h
--- a/monetdb5/optimizer/opt_prelude.h
+++ b/monetdb5/optimizer/opt_prelude.h
@@ -280,6 +280,7 @@ mal_export  str not_uniqueRef;
 mal_export  str sampleRef;
 mal_export  str uniqueRef;
 mal_export  str subuniformRef;
+mal_export  str subweightedRef;
 mal_export  str unpackRef;
 mal_export  str unpinRef;
 mal_export  str unlockRef;
diff --git a/monetdb5/optimizer/opt_support.c b/monetdb5/optimizer/opt_support.c
--- a/monetdb5/optimizer/opt_support.c
+++ b/monetdb5/optimizer/opt_support.c
@@ -549,7 +549,9 @@ isSlice(InstrPtr p)
 int 
 isSample(InstrPtr p)
 {
-       return (getModuleId(p) == sampleRef && getFunctionId(p) == 
subuniformRef);
+       return (getModuleId(p) == sampleRef && 
+                       (getFunctionId(p) == subuniformRef || 
+                        getFunctionId(p) == subweightedRef));
 }
 
 int isOrderby(InstrPtr p){
diff --git a/sql/backends/monet5/rel_bin.c b/sql/backends/monet5/rel_bin.c
--- a/sql/backends/monet5/rel_bin.c
+++ b/sql/backends/monet5/rel_bin.c
@@ -530,9 +530,8 @@ exp_bin(backend *be, sql_exp *e, stmt *l
                        list *ops;
                        node *n;
                        int first = 1;
-
-                       ops = sa_list(sql->sa);
-                       args = e->l;
+                       ops = sa_list(sql->sa);
+                       args = e->l;
                        for( n = args->h; n; n = n->next ) {
                                s = NULL;
                                if (!swapped)
@@ -2794,12 +2793,23 @@ rel2bin_sample(backend *be, sql_rel *rel
                const char *tname = table_name(sql->sa, sc);
 
                s = exp_bin(be, rel->exps->h->data, NULL, NULL, NULL, NULL, 
NULL, NULL);
-
                if (!s)
                        s = stmt_atom_lng_nil(be);
 
                sc = column(be, sc);
-               sample = stmt_sample(be, stmt_alias(be, sc, tname, cname),s);
+
+               if (rel->exps->h->next) {
+                       stmt* left = rel_bin(be, rel->l);
+                       stmt* right = rel_bin(be, rel->r);
+                       // weighted sampling
+                       stmt* weights = exp_bin(be, rel->exps->h->next->data, 
left, right, NULL, NULL, NULL, NULL);
+                       if (!weights)
+                               return NULL;
+
+                       sample = stmt_weighted_sample(be, stmt_alias(be, sc, 
tname, cname), s, weights);
+               } else {
+                       sample = stmt_sample(be, stmt_alias(be, sc, tname, 
cname), s);
+               }
 
                for ( ; n; n = n->next) {
                        stmt *sc = n->data;
diff --git a/sql/backends/monet5/sql_statement.c 
b/sql/backends/monet5/sql_statement.c
--- a/sql/backends/monet5/sql_statement.c
+++ b/sql/backends/monet5/sql_statement.c
@@ -1013,6 +1013,50 @@ stmt_sample(backend *be, stmt *s, stmt *
       return NULL;
 }

+/*
+ * Emit the MAL call sample.subweighted(s, sample, weights) and wrap its
+ * result in an st_sample statement node.
+ *
+ * s       - column to sample from
+ * sample  - sample-size argument
+ * weights - per-row weights (presumably one per row of s; TODO confirm
+ *           against the sample.subweighted MAL signature)
+ *
+ * Returns NULL if any operand has no MAL variable (nr < 0) or if the MAL
+ * instruction could not be created.
+ */
+stmt *
+stmt_weighted_sample(backend *be, stmt *s, stmt *sample, stmt *weights)
+{
+       MalBlkPtr mb = be->mb;
+       InstrPtr q = NULL;
+
+       if (s->nr < 0 || sample->nr < 0 || weights->nr < 0)
+               return NULL;
+
+       q = newStmt(mb, sampleRef, subweightedRef);
+       if (q == NULL)
+               return NULL;
+       q = pushArgument(mb, q, s->nr);
+       q = pushArgument(mb, q, sample->nr);
+       q = pushArgument(mb, q, weights->nr);
+       if (q) {
+               stmt *ns = stmt_create(be->mvc->sa, st_sample);
+
+               ns->op1 = s;
+               ns->op2 = sample;
+               /* keep the weights input on the node alongside op1/op2 */
+               ns->op3 = weights;
+               ns->nrcols = s->nrcols;
+               ns->key = s->key;
+               ns->aggr = s->aggr;
+               ns->flag = 0;
+               ns->q = q;
+               ns->nr = getDestVar(q);
+               return ns;
+       }
+       return NULL;
+}
 
 stmt *
 stmt_order(backend *be, stmt *s, int direction)
diff --git a/sql/backends/monet5/sql_statement.h 
b/sql/backends/monet5/sql_statement.h
--- a/sql/backends/monet5/sql_statement.h
+++ b/sql/backends/monet5/sql_statement.h
@@ -208,6 +208,7 @@ extern stmt *stmt_result(backend *be, st
  */ 
 extern stmt *stmt_limit(backend *sa, stmt *c, stmt *piv, stmt *gid, stmt 
*offset, stmt *limit, int distinct, int dir, int last, int order);
 extern stmt *stmt_sample(backend *be, stmt *s, stmt *sample);
+extern stmt *stmt_weighted_sample(backend *be, stmt *s, stmt *sample, stmt 
*weights);
 extern stmt *stmt_order(backend *be, stmt *s, int direction);
 extern stmt *stmt_reorder(backend *be, stmt *s, int direction, stmt 
*orderby_ids, stmt *orderby_grp);
 
diff --git a/sql/server/rel_select.c b/sql/server/rel_select.c
--- a/sql/server/rel_select.c
+++ b/sql/server/rel_select.c
@@ -4855,10 +4855,29 @@ rel_select_exp(mvc *sql, sql_rel *rel, S
 
        if (sn->sample) {
                list *exps = new_exp_list(sql->sa);
-               sql_exp *o = rel_value_exp( sql, NULL, sn->sample, 0, ek);
-               if (!o)
-                       return NULL;
-               append(exps, o);
+               if (sn->sample->token == SQL_WEIGHTED_SAMPLE) {
+                       // weighted sampling
+                       // parse the sample size and weight vector and pass it 
on to rel_sample
+                       dlist *l = sn->sample->data.lval;
+
+                       lng sample_size = l->h->data.l_val;
+                       sql_exp* sample_size_exp = exp_atom_lng(sql->sa, 
sample_size);
+
+                       symbol* weights = l->h->next->data.sym;
+                       sql_exp* weights_exp = rel_value_exp(sql, &rel, 
weights, 0, ek);
+
+                       if (!sample_size_exp || !weights_exp)
+                               return NULL;
+                       append(exps, sample_size_exp);
+                       append(exps, weights_exp);
+               } else {
+                       // uniform sampling
+                       // parse the sample size and pass it on to rel_sample
+                       sql_exp *o = rel_value_exp( sql, &rel, sn->sample, 0, 
ek);
+                       if (!o)
+                               return NULL;
+                       append(exps, o);
+               }
                rel = rel_sample(sql->sa, rel, exps);
        }
 
diff --git a/sql/server/sql_parser.h b/sql/server/sql_parser.h
--- a/sql/server/sql_parser.h
+++ b/sql/server/sql_parser.h
@@ -172,7 +172,8 @@ typedef enum tokens {
        SQL_XMLQUERY,
        SQL_XMLTEXT,
        SQL_XMLVALIDATE,
-       SQL_XMLNAMESPACES
+       SQL_XMLNAMESPACES,
+       SQL_WEIGHTED_SAMPLE
 } tokens;
 
 typedef enum jt {
diff --git a/sql/server/sql_parser.y b/sql/server/sql_parser.y
--- a/sql/server/sql_parser.y
+++ b/sql/server/sql_parser.y
@@ -602,7 +602,7 @@ SQLCODE SQLERROR UNDER WHENEVER
 %token ALTER ADD TABLE COLUMN TO UNIQUE VALUES VIEW WHERE WITH
 %token<sval> sqlDATE TIME TIMESTAMP INTERVAL
 %token YEAR MONTH DAY HOUR MINUTE SECOND ZONE
-%token LIMIT OFFSET SAMPLE
+%token LIMIT OFFSET SAMPLE WEIGHTS
 
 %token CASE WHEN THEN ELSE NULLIF COALESCE IF ELSEIF WHILE DO
 %token ATOMIC BEGIN END
@@ -3333,6 +3333,12 @@ opt_sample:
                          $$ = _newAtomNode( atom_float(SA, t, 
strtod($2,NULL)));
                        }
  |  SAMPLE param       { $$ = $2; }
+ | SAMPLE poslng WITH WEIGHTS search_condition { 
+       dlist *l = L();
+       append_lng(l, $2);
+       append_symbol(l, $5);
+       $$ = _symbol_create_list(SQL_WEIGHTED_SAMPLE, l);
+ }
  ;
 
 sort_specification_list:
diff --git a/sql/server/sql_scan.c b/sql/server/sql_scan.c
--- a/sql/server/sql_scan.c
+++ b/sql/server/sql_scan.c
@@ -208,6 +208,7 @@ scanner_init_keywords(void)
        keywords_insert("LIKE", LIKE);
        keywords_insert("LIMIT", LIMIT);
        keywords_insert("SAMPLE", SAMPLE);
+       keywords_insert("WEIGHTS", WEIGHTS);
        keywords_insert("LOCAL", LOCAL);
        keywords_insert("LOCKED", LOCKED);
        keywords_insert("NATURAL", NATURAL);
_______________________________________________
checkin-list mailing list
[email protected]
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to