Changeset: e5c5d2f2a779 for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/e5c5d2f2a779
Modified Files:
clients/Tests/MAL-signatures-hge.test
clients/Tests/MAL-signatures.test
clients/Tests/exports.stable.out
gdk/gdk.h
gdk/gdk_firstn.c
monetdb5/modules/kernel/algebra.c
Branch: default
Log Message:
Implemented algebra.groupedfirstn/BATgroupedfirstn to do grouped top-N.
diffs (truncated from 338 to 300 lines):
diff --git a/clients/Tests/MAL-signatures-hge.test
b/clients/Tests/MAL-signatures-hge.test
--- a/clients/Tests/MAL-signatures-hge.test
+++ b/clients/Tests/MAL-signatures-hge.test
@@ -3444,6 +3444,11 @@ command algebra.groupby(X_0:bat[:oid], X
ALGgroupby;
Produces a new BAT with groups identified by the head column. The result
contains tail times the head value, ie the tail contains the result group sizes.
algebra
+groupedfirstn
+pattern algebra.groupedfirstn(X_0:lng, X_1:bat[:oid], X_2:bat[:oid],
X_3:any...):bat[:oid]
+ALGgroupedfirstn;
+Grouped firstn
+algebra
intersect
command algebra.intersect(X_0:bat[:any_1], X_1:bat[:any_1], X_2:bat[:oid],
X_3:bat[:oid], X_4:bit, X_5:bit, X_6:lng):bat[:oid]
ALGintersect;
diff --git a/clients/Tests/MAL-signatures.test
b/clients/Tests/MAL-signatures.test
--- a/clients/Tests/MAL-signatures.test
+++ b/clients/Tests/MAL-signatures.test
@@ -2879,6 +2879,11 @@ command algebra.groupby(X_0:bat[:oid], X
ALGgroupby;
Produces a new BAT with groups identified by the head column. The result
contains tail times the head value, ie the tail contains the result group sizes.
algebra
+groupedfirstn
+pattern algebra.groupedfirstn(X_0:lng, X_1:bat[:oid], X_2:bat[:oid],
X_3:any...):bat[:oid]
+ALGgroupedfirstn;
+Grouped firstn
+algebra
intersect
command algebra.intersect(X_0:bat[:any_1], X_1:bat[:any_1], X_2:bat[:oid],
X_3:bat[:oid], X_4:bit, X_5:bit, X_6:lng):bat[:oid]
ALGintersect;
diff --git a/clients/Tests/exports.stable.out b/clients/Tests/exports.stable.out
--- a/clients/Tests/exports.stable.out
+++ b/clients/Tests/exports.stable.out
@@ -135,6 +135,7 @@ BAT *BATgroupcorrelation(BAT *b1, BAT *b
BAT *BATgroupcount(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool skip_nils);
BAT *BATgroupcovariance_population(BAT *b1, BAT *b2, BAT *g, BAT *e, BAT *s,
int tp, bool skip_nils);
BAT *BATgroupcovariance_sample(BAT *b1, BAT *b2, BAT *g, BAT *e, BAT *s, int
tp, bool skip_nils);
+BAT *BATgroupedfirstn(BUN n, BAT *s, BAT *g, int nbats, BAT **bats, bool *asc,
bool *nilslast) __attribute__((__warn_unused_result__));
BAT *BATgroupmax(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool skip_nils);
BAT *BATgroupmedian(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool skip_nils);
BAT *BATgroupmedian_avg(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool
skip_nils);
diff --git a/gdk/gdk.h b/gdk/gdk.h
--- a/gdk/gdk.h
+++ b/gdk/gdk.h
@@ -2340,6 +2340,8 @@ gdk_export gdk_return BATfirstn(BAT **to
__attribute__((__access__(write_only, 1)))
__attribute__((__access__(write_only, 2)))
__attribute__((__warn_unused_result__));
+gdk_export BAT *BATgroupedfirstn(BUN n, BAT *s, BAT *g, int nbats, BAT **bats,
bool *asc, bool *nilslast)
+ __attribute__((__warn_unused_result__)); /* Grouped top-N: for each group
+ * in G (restricted to candidates S), select the first N rows of the
+ * NBATS sort columns in BATS.  ASC and NILSLAST give the per-column
+ * ordering; a NULL ASC means descending for every column and a NULL
+ * NILSLAST means nils last for every column.  The result holds N
+ * consecutive oids per group, padded with nil when a group has fewer
+ * than N rows; it is an empty dense bat when N is 0 or the first
+ * column is empty, and NULL on error.  The result must not be
+ * ignored by the caller. */
#include "gdk_calc.h"
diff --git a/gdk/gdk_firstn.c b/gdk/gdk_firstn.c
--- a/gdk/gdk_firstn.c
+++ b/gdk/gdk_firstn.c
@@ -1328,3 +1328,189 @@ BATfirstn(BAT **topn, BAT **gids, BAT *b
bat_iterator_end(&bi);
return rc;
}
+
+/* Per-column state used by BATgroupedfirstn.  BI1 is used to fetch the
+ * left-hand value of a comparison and BI2 the right-hand value; both
+ * iterate over the same bat. */
+struct groupedfirstn_batinfo {
+	BATiter bi1;
+	BATiter bi2;
+	oid hseq;		/* hseqbase of the column */
+	bool asc;		/* ascending order requested */
+	bool nilslast;		/* nils sort after all other values */
+	const void *nil;	/* nil representation for the column type */
+	int (*cmp)(const void *, const void *);
+};
+
+/* Compare rows A and B (both row ids) on all NBATS sort columns, in
+ * column order, honoring the ASC and NILSLAST flag of each column.
+ *
+ * Returns a negative value if row A sorts before row B in the requested
+ * order (i.e. A is the "better" first-n candidate), a positive value if
+ * A sorts after B, and 0 if the rows are equal on every column.
+ *
+ * A nil value is placed purely according to NILSLAST, independent of
+ * ASC: with NILSLAST set it sorts after every non-nil value, otherwise
+ * before every non-nil value. */
+static inline int
+groupedfirstn_rowcmp(struct groupedfirstn_batinfo *batinfo, int nbats, oid a, oid b)
+{
+	for (int i = 0; i < nbats; i++) {
+		const void *va = BUNtail(batinfo[i].bi1, a - batinfo[i].hseq);
+		const void *vb = BUNtail(batinfo[i].bi2, b - batinfo[i].hseq);
+		int comp = batinfo[i].cmp(va, vb);
+
+		if (comp == 0)
+			continue;	/* tie on this column: look at the next */
+		if (!batinfo[i].bi1.nonil) {
+			/* the values differ, so at most one of them is nil */
+			if (batinfo[i].cmp(va, batinfo[i].nil) == 0)
+				return batinfo[i].nilslast ? 1 : -1;
+			if (batinfo[i].cmp(vb, batinfo[i].nil) == 0)
+				return batinfo[i].nilslast ? -1 : 1;
+		}
+		return batinfo[i].asc ? comp : -comp;
+	}
+	return 0;
+}
+
+/* Calculate the first N values for each group given in G of the bats in
+ * BATS (of which there are NBATS), but only considering the candidates
+ * in S.
+ *
+ * Conceptually, the bats in BATS are sorted per group, taking the
+ * candidate list S and the values in ASC and NILSLAST into account.
+ * For each group, the first N rows are then returned.
+ *
+ * For each bat, the sort order that is to be used is specified in the
+ * array ASC.  The first N values means the smallest N values if asc is
+ * set, the largest if not set.  If NILSLAST for a bat is set, nils are
+ * only returned if there are not enough non-nil values; if nilslast is
+ * not set, nils are returned preferentially.  If ASC is NULL, all bats
+ * are treated as descending; if NILSLAST is NULL, nils are last for all
+ * bats.
+ *
+ * The return value is a bat with N consecutive values for each group in
+ * G.  Values are nil if there are not enough values in the group, else
+ * they are row ids of the first rows.  The N values of a group are
+ * stored in heap order, not in sorted order.
+ *
+ * If N is 0 or the first bat is empty, an empty dense bat is returned.
+ * On error (including a query timeout), NULL is returned.
+ */
+BAT *
+BATgroupedfirstn(BUN n, BAT *s, BAT *g, int nbats, BAT **bats, bool *asc, bool *nilslast)
+{
+	const char *err;
+	oid min, max;
+	BUN ngrp;
+	struct canditer ci;
+	QryCtx *qry_ctx = MT_thread_get_qry_ctx();
+	struct groupedfirstn_batinfo *batinfo;
+
+	assert(nbats > 0);
+
+	if (n == 0 || BATcount(bats[0]) == 0) {
+		return BATdense(0, 0, 0);
+	}
+
+	if ((err = BATgroupaggrinit(bats[0], g, NULL /* e */, s, &min, &max, &ngrp, &ci)) != NULL) {
+		GDKerror("%s\n", err);
+		return NULL;
+	}
+
+	batinfo = GDKmalloc(nbats * sizeof(*batinfo));
+	if (batinfo == NULL)
+		return NULL;
+
+	/* every slot starts out as nil, meaning "not yet filled" */
+	BAT *bn = BATconstant(0, TYPE_oid, &oid_nil, ngrp * n, TRANSIENT);
+	if (bn == NULL) {
+		GDKfree(batinfo);
+		return NULL;
+	}
+
+	for (int i = 0; i < nbats; i++) {
+		batinfo[i] = (struct groupedfirstn_batinfo) {
+			.bi1 = bat_iterator(bats[i]),
+			.bi2 = bat_iterator(bats[i]),
+			.asc = asc ? asc[i] : false,
+			.nilslast = nilslast ? nilslast[i] : true,
+			.cmp = ATOMcompare(bats[i]->ttype),
+			.hseq = bats[i]->hseqbase,
+			.nil = ATOMnilptr(bats[i]->ttype),
+		};
+	}
+
+	/* For each group we maintain a "heap" of N values inside the
+	 * return bat BN.  The heap for group GRP is located in BN at
+	 * BUN [grp*N..(grp+1)*N).  The first value in this heap is the
+	 * "largest" (assuming all ASC bits are set) value so far, i.e.
+	 * the row that sorts last among the rows kept for the group.
+	 * Unfilled (nil) slots count as sorting after everything, so
+	 * they surface at the root until the heap is full. */
+	oid *oids = Tloc(bn, 0);
+	TIMEOUT_LOOP(ci.ncand, qry_ctx) {
+		oid o = canditer_next(&ci);
+		oid grp = g ? BUNtoid(g, o - g->hseqbase) : 0;
+		/* Never write outside the ngrp*N slots allocated above:
+		 * rows with a nil or out-of-range group id are skipped.
+		 * NOTE(review): the MIN offset reported by
+		 * BATgroupaggrinit is not applied here; confirm that
+		 * group ids are always zero-based for this function. */
+		if (is_oid_nil(grp) || grp >= ngrp)
+			continue;
+		BUN goff = grp * n;
+		/* If the heap is full (root is not nil), the incoming
+		 * row only enters when it sorts strictly before the
+		 * current root; on a tie the earlier row is kept. */
+		if (!is_oid_nil(oids[goff]) &&
+		    groupedfirstn_rowcmp(batinfo, nbats, o, oids[goff]) >= 0)
+			continue;
+		/* replace the root and sift the new row down */
+		oids[goff] = o;
+		BUN pos = 0;
+		BUN childpos = 1;
+		while (childpos < n) {
+			/* find most extreme child: an unfilled slot if
+			 * there is one, else the child that sorts last */
+			if (childpos + 1 < n && !is_oid_nil(oids[goff + childpos])) {
+				if (is_oid_nil(oids[goff + childpos + 1]) ||
+				    groupedfirstn_rowcmp(batinfo, nbats,
+							 oids[goff + childpos],
+							 oids[goff + childpos + 1]) < 0)
+					childpos++;
+			}
+			/* compare parent with most extreme child */
+			if (!is_oid_nil(oids[goff + childpos]) &&
+			    groupedfirstn_rowcmp(batinfo, nbats,
+						 oids[goff + pos],
+						 oids[goff + childpos]) >= 0) {
+				/* already correctly ordered */
+				break;
+			}
+			oid tmp = oids[goff + pos];
+			oids[goff + pos] = oids[goff + childpos];
+			oids[goff + childpos] = tmp;
+			pos = childpos;
+			childpos = (pos << 1) + 1;
+		}
+	}
+	for (int i = 0; i < nbats; i++) {
+		bat_iterator_end(&batinfo[i].bi1);
+		bat_iterator_end(&batinfo[i].bi2);
+	}
+	GDKfree(batinfo);
+	/* iterators and batinfo are released above on every path, so a
+	 * timeout only has to give up the result bat */
+	TIMEOUT_CHECK(qry_ctx, GOTO_LABEL_TIMEOUT_HANDLER(bailout, qry_ctx));
+	return bn;
+
+  bailout:
+	BBPreclaim(bn);
+	return NULL;
+}
diff --git a/monetdb5/modules/kernel/algebra.c
b/monetdb5/modules/kernel/algebra.c
--- a/monetdb5/modules/kernel/algebra.c
+++ b/monetdb5/modules/kernel/algebra.c
@@ -992,6 +992,80 @@ ALGfirstn(Client cntxt, MalBlkPtr mb, Ma
}
static str
+ALGgroupedfirstn(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci)
+{
+ bat *ret;
+ bat sid, gid;
+ BAT *s = NULL, *g = NULL;
+ BAT *bn = NULL;
+ lng n;
+
+ (void) cntxt;
+ (void) mb;
+
+ n = *getArgReference_lng(stk, pci, 1);
+ if (n < 0)
+ throw(MAL, "algebra.groupedfirstn", ILLEGAL_ARGUMENT);
+ ret = getArgReference_bat(stk, pci, 0);
+ sid = *getArgReference_bat(stk, pci, 2);
+ gid = *getArgReference_bat(stk, pci, 3);
+ int nbats = pci->argc - 4;
+ if (nbats % 3 != 0)
+ throw(MAL, "algebra.groupedfirstn", ILLEGAL_ARGUMENT);
+ nbats /= 3;
+ BAT **bats = GDKmalloc(nbats * sizeof(BAT *));
+ bool *ascs = GDKmalloc(nbats * sizeof(bool));
+ bool *nlss = GDKmalloc(nbats * sizeof(bool));
+ if (bats == NULL || ascs == NULL || nlss == NULL) {
+ GDKfree(bats);
+ GDKfree(ascs);
+ GDKfree(nlss);
+ throw(MAL, "algebra.groupedfirstn", MAL_MALLOC_FAIL);
+ }
+ if (!is_bat_nil(sid) && (s = BATdescriptor(sid)) == NULL) {
+ GDKfree(bats);
+ GDKfree(ascs);
+ GDKfree(nlss);
+ throw(MAL, "algebra.groupedfirstn", SQLSTATE(HY002)
RUNTIME_OBJECT_MISSING);
+ }
+ if (!is_bat_nil(gid) && (g = BATdescriptor(gid)) == NULL) {
+ BBPreclaim(s);
+ GDKfree(bats);
+ GDKfree(ascs);
+ GDKfree(nlss);
+ throw(MAL, "algebra.groupedfirstn", SQLSTATE(HY002)
RUNTIME_OBJECT_MISSING);
+ }
+ for (int i = 0; i < nbats; i++) {
+ bats[i] = BATdescriptor(*getArgReference_bat(stk, pci, i * 3 +
4));
+ if (bats[i] == NULL) {
+ while (i > 0)
_______________________________________________
checkin-list mailing list -- [email protected]
To unsubscribe send an email to [email protected]