Changeset: 2ad57d092fcd for MonetDB
URL: http://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=2ad57d092fcd
Modified Files:
        gdk/gdk.h
        gdk/gdk_sample.c
        monetdb5/modules/mal/sample.c
        monetdb5/modules/mal/sample.h
        monetdb5/modules/mal/sample.mal
        sql/backends/monet5/Makefile.ag
Branch: stratified_sampling
Log Message:

implement weighted sampling UDF


diffs (truncated from 386 to 300 lines):

diff --git a/gdk/gdk.h b/gdk/gdk.h
--- a/gdk/gdk.h
+++ b/gdk/gdk.h
@@ -3026,12 +3026,17 @@ gdk_export gdk_return BATfirstn(BAT **to
  * @tab BATsample (BAT *b, n)
  * @end multitable
  *
- * The routine BATsample returns a random sample on n BUNs of a BAT.
+ * The routine BATsample returns a random sample containing n BUNs of a BAT.
  *
  */
 gdk_export BAT *BATsample(BAT *b, BUN n);
 
 /*
+ * The routine BATweightedsample returns a weighted random sample containing n BUNs of a BAT.
+ */
+gdk_export BAT *BATweightedsample(BAT *b, BUN n, BAT *w);
+
+/*
  *
  */
 #define MAXPARAMS      32
diff --git a/gdk/gdk_sample.c b/gdk/gdk_sample.c
--- a/gdk/gdk_sample.c
+++ b/gdk/gdk_sample.c
@@ -31,6 +31,7 @@
 
 #undef BATsample
 
+
 /* this is a straightforward implementation of a binary tree */
 struct oidtreenode {
        oid o;
@@ -81,21 +82,31 @@ OIDTreeToBATAntiset(struct oidtreenode *
        oid noid;
 
        if (node->left != NULL)
-                       OIDTreeToBATAntiset(node->left, bat, start, node->o);
+               OIDTreeToBATAntiset(node->left, bat, start, node->o);
        else
                for (noid = start; noid < node->o; noid++)
                        ((oid *) bat->T->heap.base)[bat->batFirst + 
bat->batCount++] = noid;
 
                if (node->right != NULL)
-               OIDTreeToBATAntiset(node->right, bat, node->o + 1, stop);
+                       OIDTreeToBATAntiset(node->right, bat, node->o + 1, 
stop);
        else
-               for (noid = node->o+1; noid < stop; noid++)
-                                               ((oid *) 
bat->T->heap.base)[bat->batFirst + bat->batCount++] = noid;
+               for (noid = node->o + 1; noid < stop; noid++)
+                       ((oid *) bat->T->heap.base)[bat->batFirst + 
bat->batCount++] = noid;
 }
 
-/* BATsample implements sampling for void headed BATs */
-BAT *
-BATsample(BAT *b, BUN n)
+/* inorder traversal, gives us a bit BAT */
+/*BAT *bat OIDTreeToBITBAT(struct oidtreenode)
+{
+       //TODO create this function
+}*/
+
+/* 
+ * _BATsample is the internal (weighted) sampling function without replacement
+ * If cdf=NULL, a uniform sample is taken
+ * Otherwise it is assumed the cdf increases monotonically
+ */
+static BAT *
+_BATsample(BAT *b, BUN n, BAT *cdf)
 {
        BAT *bn;
        BUN cnt, slen;
@@ -103,6 +114,9 @@ BATsample(BAT *b, BUN n)
        struct oidtreenode *tree = NULL;
        mtwist *mt_rng;
        unsigned int range;
+       dbl random;
+       dbl cdf_max;
+       dbl* cdf_ptr;
 
        BATcheck(b, "BATsample", NULL);
        assert(BAThdense(b));
@@ -148,7 +162,6 @@ BATsample(BAT *b, BUN n)
                        GDKfree(tree);
                        return NULL;
                }
-               /* while we do not have enough sample OIDs yet */
                
                /* create and seed Mersenne Twister */
                mt_rng = mtwist_new();
@@ -157,16 +170,49 @@ BATsample(BAT *b, BUN n)
                
                range = maxoid - minoid;
                
-               for (rescnt = 0; rescnt < n; rescnt++) {
-                       oid candoid;
-                       do {
-                               /* generate a new random OID in [minoid, maxoid[
-                                * that is including minoid, excluding maxoid*/
-                               candoid = (oid) ( minoid + 
(mtwist_u32rand(mt_rng)%range) );
-                               /* if that candidate OID was already
-                                * generated, try again */
-                       } while (!OIDTreeMaybeInsert(tree, candoid, rescnt));
+               /* sample OIDs (method depends on w) */
+               if(cdf == NULL) {
+                       /* no weights, hence do uniform sampling */
+
+                       /* while we do not have enough sample OIDs yet */
+                       for (rescnt = 0; rescnt < n; rescnt++) {
+                               oid candoid;
+                               do {
+                                       /* generate a new random OID in 
[minoid, maxoid[
+                                        * that is including minoid, excluding 
maxoid*/
+                                       candoid = (oid) ( minoid + 
(mtwist_u32rand(mt_rng)%range) );
+                                       /* if that candidate OID was already
+                                        * generated, try again */
+                               } while (!OIDTreeMaybeInsert(tree, candoid, 
rescnt));
+                       }
+
+               } else {
+                       /* do weighted sampling */
+                       
+                       cdf_ptr = (dbl*) Tloc(cdf, BUNfirst(cdf));
+                       if (!antiset)
+                               cdf_max = cdf_ptr[cnt-1];
+                       else
+                               cdf_max = cdf_ptr[0];
+                          //TODO how to type/cast cdf_max?
+
+                       /* generate candoids, using CDF */
+                       for (rescnt = 0; rescnt < n; rescnt++) {
+                               oid candoid;
+
+                               do {
+                                       random = mtwist_drand(mt_rng)*cdf_max;
+                                       /* generate a new random OID in 
[minoid, maxoid[
+                                        * that is including minoid, excluding 
maxoid*/
+                                       /* note that cdf has already been 
adjusted for antiset case */
+                                       candoid = (oid) ( minoid + (oid) 
SORTfndfirst(cdf, &random) );
+                                       /* if that candidate OID was already
+                                        * generated, try again */
+                               } while (!OIDTreeMaybeInsert(tree, candoid, 
rescnt));
+                       }
                }
+
+
                if (!antiset) {
                        OIDTreeToBAT(tree, bn);
                } else {
@@ -189,3 +235,96 @@ BATsample(BAT *b, BUN n)
        }
        return bn;
 }
+
+
+/* Uniform sampling of void headed BATs: delegates to _BATsample without a cdf */
+BAT *
+BATsample(BAT *b, BUN n)
+{
+       BAT *sample = _BATsample(b, n, NULL);
+
+       return sample;
+}
+
+/*
+ * BATweightedsample takes weighted samples of void headed BATs.
+ *
+ *   b  the BAT to sample from
+ *   n  the requested number of BUNs in the sample
+ *   w  the weights, one per BUN of b; for now this must be a dbl BAT.
+ *      A nil weight contributes nothing to the distribution.
+ *
+ * The weights are turned into a cumulative distribution function (cdf)
+ * which is handed to _BATsample; the result of _BATsample is returned.
+ * Returns NULL if an argument is rejected or the cdf cannot be allocated.
+ */
+BAT *
+BATweightedsample(BAT *b, BUN n, BAT *w)
+{
+       BAT *cdf;
+       BAT *sample;
+       dbl *w_ptr;//TODO types of w
+       dbl *cdf_ptr;
+       dbl total;
+       BUN cnt, i;
+       bit antiset;
+
+       BATcheck(b, "BATsample", NULL);
+       BATcheck(w, "BATsample", NULL);
+
+       ERRORcheck(w->ttype == TYPE_str || w->ttype == TYPE_void,
+                                       "BATsample: type of weights not castable to doubles\n", NULL);
+       ERRORcheck(w->ttype != TYPE_dbl,
+                                       "BATsample: type of weights must be doubles\n", NULL);//TODO types of w (want to remove this)
+
+       cnt = BATcount(b);
+
+       /* every BUN of b needs a weight, otherwise the loop below would
+        * read past the end of w */
+       ERRORcheck(BATcount(w) < cnt,
+                                       "BATsample: too few weights\n", NULL);
+
+       /* same threshold that decides between set and antiset notation */
+       antiset = n > cnt / 2;
+
+       cdf = BATnew(TYPE_void, TYPE_dbl, cnt, TRANSIENT);
+       if (cdf == NULL)
+               return NULL;
+       BATsetcount(cdf, cnt);
+
+       /* calculate cumulative distribution function */
+       w_ptr = (dbl *) Tloc(w, BUNfirst(w));//TODO support different types w
+       cdf_ptr = (dbl *) Tloc(cdf, BUNfirst(cdf));
+
+       /* a running total starting at zero handles an empty b (nothing is
+        * read) and treats a nil first weight like any other nil weight */
+       total = 0;
+       for (i = 0; i < cnt; i++) {
+               if (w_ptr[i] != dbl_nil) {//TODO fix NULL-test if w can have different types
+                       total += w_ptr[i];
+               }
+               cdf_ptr[i] = total;
+       }
+       if (!antiset) {
+               cdf->tsorted = 1;
+               cdf->trevsorted = cnt <= 1;
+       } else {
+               /* in antiset notation, we have to flip probabilities */
+               for (i = 0; i < cnt; i++) {
+                       cdf_ptr[i] = total - cdf_ptr[i];
+               }
+               cdf->tsorted = cnt <= 1;
+               cdf->trevsorted = 1;
+       }
+
+       /* obtain sample */
+       sample = _BATsample(b, n, cdf);
+
+       /* NOTE(review): confirm BATdelete is the intended way to release
+        * this transient helper BAT */
+       BATdelete(cdf);
+
+       return sample;
+}
+
+
+/* BATweightedbitbat creates a bit BAT of length cnt containing n 1s and cnt-n 0s */
+/* Note that the type of w should be castable to doubles */
+/*BAT *
+BATweightedbitbat(BUN cnt, BUN n, BAT *w)
+{
+       BAT* res;
+       res = BATnew(TYPE_void, TYPE_dbl, cnt, TRANSIENT);
+       BATsetcount(res, cnt);
+       
+       //Need to adjust _BATsample so it will return a bit BAT with bools 
denoting if element is selected
+       //Now it will rather return a subset
+       //TODO rewrite _BATsample to support this, add call to _BATsample
+       //Why did we choose for this UDF notation?
+       //+ easier to implement (no parsing addition)
+       //- slow
+       //- actually yields uglier code
+       //Why implement something like that? Hence we should choose for the 
other notation?
+       
+       
+       return res;
+}
+*/
+
+
diff --git a/monetdb5/modules/mal/sample.c b/monetdb5/modules/mal/sample.c
--- a/monetdb5/modules/mal/sample.c
+++ b/monetdb5/modules/mal/sample.c
@@ -54,10 +54,10 @@
  * CREATE FUNCTION mysample ()
  * RETURNS TABLE(col a,...)
  * BEGIN
- *    RETURN
- *      SELECT a,...
- *      FROM name_table
- *      SAMPLE 100;
+ *     RETURN
+ *       SELECT a,...
+ *       FROM name_table
+ *       SAMPLE 100;
  * end;
  *
  * and then use function mysample() for example to populate a new table with
@@ -104,3 +104,43 @@ SAMPLEuniform_dbl(bat *r, bat *b, dbl *p
        BBPunfix(bb->batCacheid);
        return SAMPLEuniform(r, b, &s);
 }
+
+/*
+ * MAL wrapper around BATweightedsample.
+ *
+ *   r  receives the id of the resulting sample BAT
+ *   b  id of the BAT to sample from
+ *   s  requested sample size (cast to BUN)
+ *   w  id of the BAT holding the weights
+ *
+ * Returns MAL_SUCCEED, or throws a MAL exception when a BAT descriptor
+ * cannot be obtained or the sampling itself fails.  Both input
+ * descriptors are released on every path, including the error paths.
+ */
+str
+SAMPLEweighted(bat *r, bat *b, wrd *s, bat *w) {
+       BAT *br, *bb, *bw;
+
+       if ((bw = BATdescriptor(*w)) == NULL ) {
+               throw(MAL, "sample.subweighted", INTERNAL_BAT_ACCESS);
+       }
+       if ((bb = BATdescriptor(*b)) == NULL ) {
+               /* do not leak the weights descriptor we already hold */
+               BBPunfix(bw->batCacheid);
+               throw(MAL, "sample.subweighted", INTERNAL_BAT_ACCESS);
+       }
+       br = BATweightedsample(bb, (BUN) *s, bw);
+
+       /* the inputs are no longer needed, whether sampling succeeded or not */
+       BBPunfix(bb->batCacheid);
+       BBPunfix(bw->batCacheid);
+
+       if (br == NULL)
+               throw(MAL, "sample.subweighted", OPERATION_FAILED);
+
+       BBPkeepref(*r = br->batCacheid);
+       return MAL_SUCCEED;
+}
+
+str
+SAMPLEweighted_dbl(bat *r, bat *b, dbl *p, bat *w) {
+       BAT *bb;
+       double pr = *p;
+       wrd s;
+
+       if ( pr < 0.0 || pr > 1.0 ) {
+               throw(MAL, "sample.subweighted", ILLEGAL_ARGUMENT
+                               " p should be between 0 and 1.0" );
+       } else if (pr == 0) {/* special case */
+               s = 0;
+               return SAMPLEweighted(r, b, &s, w);
_______________________________________________
checkin-list mailing list
[email protected]
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to