Changeset: 2ad57d092fcd for MonetDB
URL: http://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=2ad57d092fcd
Modified Files:
gdk/gdk.h
gdk/gdk_sample.c
monetdb5/modules/mal/sample.c
monetdb5/modules/mal/sample.h
monetdb5/modules/mal/sample.mal
sql/backends/monet5/Makefile.ag
Branch: stratified_sampling
Log Message:
implement weighted sampling UDF
diffs (truncated from 386 to 300 lines):
diff --git a/gdk/gdk.h b/gdk/gdk.h
--- a/gdk/gdk.h
+++ b/gdk/gdk.h
@@ -3026,12 +3026,17 @@ gdk_export gdk_return BATfirstn(BAT **to
* @tab BATsample (BAT *b, n)
* @end multitable
*
- * The routine BATsample returns a random sample on n BUNs of a BAT.
+ * The routine BATsample returns a random sample containing n BUNs of a BAT.
*
*/
gdk_export BAT *BATsample(BAT *b, BUN n);
/*
+ * The routine BATweightedsample returns a weighted random sample containing n BUNs of a BAT.
+ */
+gdk_export BAT *BATweightedsample(BAT *b, BUN n, BAT *w);
+
+/*
*
*/
#define MAXPARAMS 32
diff --git a/gdk/gdk_sample.c b/gdk/gdk_sample.c
--- a/gdk/gdk_sample.c
+++ b/gdk/gdk_sample.c
@@ -31,6 +31,7 @@
#undef BATsample
+
/* this is a straightforward implementation of a binary tree */
struct oidtreenode {
oid o;
@@ -81,21 +82,31 @@ OIDTreeToBATAntiset(struct oidtreenode *
oid noid;
if (node->left != NULL)
- OIDTreeToBATAntiset(node->left, bat, start, node->o);
+ OIDTreeToBATAntiset(node->left, bat, start, node->o);
else
for (noid = start; noid < node->o; noid++)
((oid *) bat->T->heap.base)[bat->batFirst +
bat->batCount++] = noid;
if (node->right != NULL)
- OIDTreeToBATAntiset(node->right, bat, node->o + 1, stop);
+ OIDTreeToBATAntiset(node->right, bat, node->o + 1,
stop);
else
- for (noid = node->o+1; noid < stop; noid++)
- ((oid *)
bat->T->heap.base)[bat->batFirst + bat->batCount++] = noid;
+ for (noid = node->o + 1; noid < stop; noid++)
+ ((oid *) bat->T->heap.base)[bat->batFirst +
bat->batCount++] = noid;
}
-/* BATsample implements sampling for void headed BATs */
-BAT *
-BATsample(BAT *b, BUN n)
+/* inorder traversal, gives us a bit BAT */
+/*BAT *bat OIDTreeToBITBAT(struct oidtreenode)
+{
+ //TODO create this function
+}*/
+
+/*
+ * _BATsample is the internal (weighted) sampling function without replacement
+ * If cdf=NULL, a uniform sample is taken
+ * Otherwise it is assumed the cdf increases monotonically
+ */
+static BAT *
+_BATsample(BAT *b, BUN n, BAT *cdf)
{
BAT *bn;
BUN cnt, slen;
@@ -103,6 +114,9 @@ BATsample(BAT *b, BUN n)
struct oidtreenode *tree = NULL;
mtwist *mt_rng;
unsigned int range;
+ dbl random;
+ dbl cdf_max;
+ dbl* cdf_ptr;
BATcheck(b, "BATsample", NULL);
assert(BAThdense(b));
@@ -148,7 +162,6 @@ BATsample(BAT *b, BUN n)
GDKfree(tree);
return NULL;
}
- /* while we do not have enough sample OIDs yet */
/* create and seed Mersenne Twister */
mt_rng = mtwist_new();
@@ -157,16 +170,49 @@ BATsample(BAT *b, BUN n)
range = maxoid - minoid;
- for (rescnt = 0; rescnt < n; rescnt++) {
- oid candoid;
- do {
- /* generate a new random OID in [minoid, maxoid[
- * that is including minoid, excluding maxoid*/
- candoid = (oid) ( minoid +
(mtwist_u32rand(mt_rng)%range) );
- /* if that candidate OID was already
- * generated, try again */
- } while (!OIDTreeMaybeInsert(tree, candoid, rescnt));
+ /* sample OIDs (method depends on w) */
+ if(cdf == NULL) {
+ /* no weights, hence do uniform sampling */
+
+ /* while we do not have enough sample OIDs yet */
+ for (rescnt = 0; rescnt < n; rescnt++) {
+ oid candoid;
+ do {
+ /* generate a new random OID in
[minoid, maxoid[
+ * that is including minoid, excluding
maxoid*/
+ candoid = (oid) ( minoid +
(mtwist_u32rand(mt_rng)%range) );
+ /* if that candidate OID was already
+ * generated, try again */
+ } while (!OIDTreeMaybeInsert(tree, candoid,
rescnt));
+ }
+
+ } else {
+ /* do weighted sampling */
+
+ cdf_ptr = (dbl*) Tloc(cdf, BUNfirst(cdf));
+ if (!antiset)
+ cdf_max = cdf_ptr[cnt-1];
+ else
+ cdf_max = cdf_ptr[0];
+ //TODO how to type/cast cdf_max?
+
+ /* generate candoids, using CDF */
+ for (rescnt = 0; rescnt < n; rescnt++) {
+ oid candoid;
+
+ do {
+ random = mtwist_drand(mt_rng)*cdf_max;
+ /* generate a new random OID in
[minoid, maxoid[
+ * that is including minoid, excluding
maxoid*/
+ /* note that cdf has already been
adjusted for antiset case */
+ candoid = (oid) ( minoid + (oid)
SORTfndfirst(cdf, &random) );
+ /* if that candidate OID was already
+ * generated, try again */
+ } while (!OIDTreeMaybeInsert(tree, candoid,
rescnt));
+ }
}
+
+
if (!antiset) {
OIDTreeToBAT(tree, bn);
} else {
@@ -189,3 +235,96 @@ BATsample(BAT *b, BUN n)
}
return bn;
}
+
+
+/* BATsample draws a uniform sample from a void headed BAT: it hands the
+ * request to _BATsample without a cdf, which selects uniform sampling */
+BAT *
+BATsample(BAT *b, BUN n)
+{
+	BAT *uniform = _BATsample(b, n, NULL);
+
+	return uniform;
+}
+
+/* BATweightedsample takes weighted samples of void headed BATs */
+/* Note that the type of w should be castable to doubles */
+/*
+ * b is the BAT to sample from, n the requested number of BUNs and w holds
+ * one weight per tuple of b (currently only dbl weights are accepted).
+ * nil weights contribute nothing to the distribution. Weights are assumed
+ * to be non-negative so that the cdf handed to _BATsample is monotonic.
+ * Returns the sample, or NULL on error.
+ */
+BAT *
+BATweightedsample(BAT *b, BUN n, BAT *w)
+{
+	BAT *cdf;
+	BAT *sample;
+	dbl *w_ptr;//TODO types of w
+	dbl *cdf_ptr;
+	dbl sum, total;
+	BUN cnt, i;
+	bit antiset;
+
+	BATcheck(b, "BATsample", NULL);
+	BATcheck(w, "BATsample", NULL);
+
+	ERRORcheck(w->ttype == TYPE_str || w->ttype == TYPE_void,
+			"BATsample: type of weights not castable to doubles\n", NULL);
+	ERRORcheck(w->ttype != TYPE_dbl,
+			"BATsample: type of weights must be doubles\n", NULL);//TODO types of w (want to remove this)
+
+	cnt = BATcount(b);
+
+	/* every tuple of b needs exactly one weight, otherwise we would read
+	 * past the end of w while building the cdf */
+	ERRORcheck(BATcount(w) != cnt,
+			"BATsample: number of weights must match number of tuples\n", NULL);
+
+	/* nothing to weigh: do not touch element 0 of an empty cdf;
+	 * presumably the uniform path deals with an empty input itself */
+	if (cnt == 0)
+		return BATsample(b, n);
+
+	antiset = n > cnt / 2;
+
+	cdf = BATnew(TYPE_void, TYPE_dbl, cnt, TRANSIENT);
+	if (cdf == NULL)
+		return NULL;
+	BATsetcount(cdf, cnt);
+
+	/* calculate cumulative distribution function */
+	w_ptr = (dbl *) Tloc(w, BUNfirst(w));//TODO support different types w
+	cdf_ptr = (dbl *) Tloc(cdf, BUNfirst(cdf));
+
+	sum = 0;
+	for (i = 0; i < cnt; i++) {
+		/* nil weights (the very first one included) add nothing */
+		if (w_ptr[i] != dbl_nil)//TODO fix NULL-test if w can have different types
+			sum += w_ptr[i];
+		cdf_ptr[i] = sum;
+	}
+	if (!antiset) {
+		cdf->tsorted = 1;
+		cdf->trevsorted = cnt <= 1;
+	} else {
+		/* in antiset notation, we have to flip probabilities;
+		 * keep the total aside as the last entry is overwritten too */
+		total = sum;
+		for (i = 0; i < cnt; i++) {
+			cdf_ptr[i] = total - cdf_ptr[i];
+		}
+		cdf->tsorted = cnt <= 1;
+		cdf->trevsorted = 1;
+	}
+
+	/* obtain sample */
+	sample = _BATsample(b, n, cdf);
+
+	/* drop our reference so the temporary cdf is actually released */
+	BBPunfix(cdf->batCacheid);
+
+	return sample;
+}
+
+
+/* BATweightedbitbat creates a bit BAT of length cnt containing n 1s and cnt-n 0s */
+/* Note that the type of w should be castable to doubles */
+/*BAT *
+BATweightedbitbat(BUN cnt, BUN n, BAT *w)
+{
+ BAT* res;
+ res = BATnew(TYPE_void, TYPE_dbl, cnt, TRANSIENT);
+ BATsetcount(res, cnt);
+
+ //Need to adjust _BATsample so it will return a bit BAT with bools
denoting if element is selected
+ //Now it will rather return a subset
+ //TODO rewrite _BATsample to support this, add call to _BATsample
+ //Why did we choose for this UDF notation?
+ //+ easier to implement (no parsing addition)
+ //- slow
+ //- actually yields uglier code
+ //Why implement something like that? Hence we should choose for the
other notation?
+
+
+ return res;
+}
+*/
+
+
diff --git a/monetdb5/modules/mal/sample.c b/monetdb5/modules/mal/sample.c
--- a/monetdb5/modules/mal/sample.c
+++ b/monetdb5/modules/mal/sample.c
@@ -54,10 +54,10 @@
* CREATE FUNCTION mysample ()
* RETURNS TABLE(col a,...)
* BEGIN
- * RETURN
- * SELECT a,...
- * FROM name_table
- * SAMPLE 100;
+ * RETURN
+ * SELECT a,...
+ * FROM name_table
+ * SAMPLE 100;
* end;
*
* and then use function mysample() for example to populate a new table with
@@ -104,3 +104,43 @@ SAMPLEuniform_dbl(bat *r, bat *b, dbl *p
BBPunfix(bb->batCacheid);
return SAMPLEuniform(r, b, &s);
}
+
+/*
+ * MAL wrapper around BATweightedsample.
+ * r receives the id of the sample BAT, b is the id of the BAT to sample
+ * from, s the requested sample size and w the id of the weights BAT.
+ * Both input descriptors are released on every path, including the
+ * failure paths.
+ */
+str
+SAMPLEweighted(bat *r, bat *b, wrd *s, bat *w) {
+	BAT *br, *bb, *bw;
+
+	if ((bw = BATdescriptor(*w)) == NULL ) {
+		throw(MAL, "sample.subweighted", INTERNAL_BAT_ACCESS);
+	}
+	if ((bb = BATdescriptor(*b)) == NULL ) {
+		/* do not leak the weights descriptor we already hold */
+		BBPunfix(bw->batCacheid);
+		throw(MAL, "sample.subweighted", INTERNAL_BAT_ACCESS);
+	}
+	br = BATweightedsample(bb, (BUN) *s, bw);
+
+	/* the inputs are no longer needed, whether sampling worked or not */
+	BBPunfix(bb->batCacheid);
+	BBPunfix(bw->batCacheid);
+
+	if (br == NULL)
+		throw(MAL, "sample.subweighted", OPERATION_FAILED);
+
+	BBPkeepref(*r = br->batCacheid);
+	return MAL_SUCCEED;
+}
+
+str
+SAMPLEweighted_dbl(bat *r, bat *b, dbl *p, bat *w) {
+ BAT *bb;
+ double pr = *p;
+ wrd s;
+
+ if ( pr < 0.0 || pr > 1.0 ) {
+ throw(MAL, "sample.subweighted", ILLEGAL_ARGUMENT
+ " p should be between 0 and 1.0" );
+ } else if (pr == 0) {/* special case */
+ s = 0;
+ return SAMPLEweighted(r, b, &s, w);
_______________________________________________
checkin-list mailing list
[email protected]
https://www.monetdb.org/mailman/listinfo/checkin-list