Changeset: e5c5d2f2a779 for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/e5c5d2f2a779
Modified Files:
        clients/Tests/MAL-signatures-hge.test
        clients/Tests/MAL-signatures.test
        clients/Tests/exports.stable.out
        gdk/gdk.h
        gdk/gdk_firstn.c
        monetdb5/modules/kernel/algebra.c
Branch: default
Log Message:

Implemented algebra.groupedfirstn/BATgroupedfirstn to do grouped top-N.


diffs (truncated from 338 to 300 lines):

diff --git a/clients/Tests/MAL-signatures-hge.test 
b/clients/Tests/MAL-signatures-hge.test
--- a/clients/Tests/MAL-signatures-hge.test
+++ b/clients/Tests/MAL-signatures-hge.test
@@ -3444,6 +3444,11 @@ command algebra.groupby(X_0:bat[:oid], X
 ALGgroupby;
 Produces a new BAT with groups identified by the head column. The result 
contains tail times the head value, ie the tail contains the result group sizes.
 algebra
+groupedfirstn
+pattern algebra.groupedfirstn(X_0:lng, X_1:bat[:oid], X_2:bat[:oid], 
X_3:any...):bat[:oid]
+ALGgroupedfirstn;
+Grouped firstn
+algebra
 intersect
 command algebra.intersect(X_0:bat[:any_1], X_1:bat[:any_1], X_2:bat[:oid], 
X_3:bat[:oid], X_4:bit, X_5:bit, X_6:lng):bat[:oid]
 ALGintersect;
diff --git a/clients/Tests/MAL-signatures.test 
b/clients/Tests/MAL-signatures.test
--- a/clients/Tests/MAL-signatures.test
+++ b/clients/Tests/MAL-signatures.test
@@ -2879,6 +2879,11 @@ command algebra.groupby(X_0:bat[:oid], X
 ALGgroupby;
 Produces a new BAT with groups identified by the head column. The result 
contains tail times the head value, ie the tail contains the result group sizes.
 algebra
+groupedfirstn
+pattern algebra.groupedfirstn(X_0:lng, X_1:bat[:oid], X_2:bat[:oid], 
X_3:any...):bat[:oid]
+ALGgroupedfirstn;
+Grouped firstn
+algebra
 intersect
 command algebra.intersect(X_0:bat[:any_1], X_1:bat[:any_1], X_2:bat[:oid], 
X_3:bat[:oid], X_4:bit, X_5:bit, X_6:lng):bat[:oid]
 ALGintersect;
diff --git a/clients/Tests/exports.stable.out b/clients/Tests/exports.stable.out
--- a/clients/Tests/exports.stable.out
+++ b/clients/Tests/exports.stable.out
@@ -135,6 +135,7 @@ BAT *BATgroupcorrelation(BAT *b1, BAT *b
 BAT *BATgroupcount(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool skip_nils);
 BAT *BATgroupcovariance_population(BAT *b1, BAT *b2, BAT *g, BAT *e, BAT *s, 
int tp, bool skip_nils);
 BAT *BATgroupcovariance_sample(BAT *b1, BAT *b2, BAT *g, BAT *e, BAT *s, int 
tp, bool skip_nils);
+BAT *BATgroupedfirstn(BUN n, BAT *s, BAT *g, int nbats, BAT **bats, bool *asc, 
bool *nilslast) __attribute__((__warn_unused_result__));
 BAT *BATgroupmax(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool skip_nils);
 BAT *BATgroupmedian(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool skip_nils);
 BAT *BATgroupmedian_avg(BAT *b, BAT *g, BAT *e, BAT *s, int tp, bool 
skip_nils);
diff --git a/gdk/gdk.h b/gdk/gdk.h
--- a/gdk/gdk.h
+++ b/gdk/gdk.h
@@ -2340,6 +2340,8 @@ gdk_export gdk_return BATfirstn(BAT **to
        __attribute__((__access__(write_only, 1)))
        __attribute__((__access__(write_only, 2)))
        __attribute__((__warn_unused_result__));
+gdk_export BAT *BATgroupedfirstn(BUN n, BAT *s, BAT *g, int nbats, BAT **bats, 
bool *asc, bool *nilslast)
+       __attribute__((__warn_unused_result__));
 
 #include "gdk_calc.h"
 
diff --git a/gdk/gdk_firstn.c b/gdk/gdk_firstn.c
--- a/gdk/gdk_firstn.c
+++ b/gdk/gdk_firstn.c
@@ -1328,3 +1328,189 @@ BATfirstn(BAT **topn, BAT **gids, BAT *b
        bat_iterator_end(&bi);
        return rc;
 }
+
+/* Per-key-column state used by BATgroupedfirstn.  Two iterators are
+ * kept on the same bat: the left operand of a comparison is always read
+ * through bi1 and the right operand through bi2, so the two BUNtail
+ * results of one comparison never come from the same iterator. */
+struct groupedfirstn_info {
+       BATiter bi1;            /* iterator used for left operands */
+       BATiter bi2;            /* iterator used for right operands */
+       oid hseq;               /* hseqbase of the bat */
+       bool asc;               /* ascending order requested */
+       bool nilslast;          /* nils rank after all other values */
+       const void *nil;        /* nil representation of the tail type */
+       int (*cmp)(const void *, const void *); /* tail type comparison */
+};
+
+/* Compare rows A and B (row ids) on all NBATS key columns in INFO, in
+ * column order, according to the requested ordering.
+ *
+ * Returns a value < 0 if row A ranks before row B (A is the "better"
+ * first-n candidate), > 0 if A ranks after B, and 0 if the rows are
+ * equal on every key column.
+ *
+ * The first column on which the two values differ decides.  If either
+ * of the two differing values is nil, only NILSLAST decides the
+ * outcome; the raw comparison result and ASC are deliberately not used
+ * in that case, since the position of nil in the type's natural order
+ * says nothing about the requested nil placement.  Nil checks are
+ * skipped for columns whose iterator reports nonil. */
+static inline int
+groupedfirstn_cmp(struct groupedfirstn_info *info, int nbats, oid a, oid b)
+{
+       for (int i = 0; i < nbats; i++) {
+               const void *va = BUNtail(info[i].bi1, a - info[i].hseq);
+               const void *vb = BUNtail(info[i].bi2, b - info[i].hseq);
+               int c = info[i].cmp(va, vb);
+
+               if (c == 0)
+                       continue;       /* tie: look at the next column */
+               if (!info[i].bi1.nonil) {
+                       /* c != 0, so at most one of va, vb is nil */
+                       if (info[i].cmp(va, info[i].nil) == 0)
+                               return info[i].nilslast ? 1 : -1;
+                       if (info[i].cmp(vb, info[i].nil) == 0)
+                               return info[i].nilslast ? -1 : 1;
+               }
+               return info[i].asc ? c : -c;
+       }
+       return 0;
+}
+
+/* Calculate the first N values for each group given in G of the bats in
+ * BATS (of which there are NBATS), but only considering the candidates
+ * in S.
+ *
+ * Conceptually, the bats in BATS are sorted per group, taking the
+ * candidate list S and the values in ASC and NILSLAST into account.
+ * For each group, the first N rows are then returned.
+ *
+ * For each bat, the sort order that is to be used is specified in the
+ * array ASC.  The first N values means the smallest N values if asc is
+ * set, the largest if not set.  If NILSLAST for a bat is set, nils are
+ * only returned if there are not enough non-nil values; if nilslast is
+ * not set, nils are returned preferentially.  If ASC is NULL, all bats
+ * are treated as descending; if NILSLAST is NULL, all bats are treated
+ * as nils-last.
+ *
+ * The return value is a bat with N consecutive values for each group in
+ * G.  Values are nil if there are not enough values in the group, else
+ * they are row ids of the first rows.  If N is 0 or the first bat is
+ * empty, an empty dense bat is returned.  NULL is returned if
+ * initialization or allocation fails, or via the bailout path taken by
+ * the final TIMEOUT_CHECK, in which case the partial result is
+ * reclaimed.
+ */
+BAT *
+BATgroupedfirstn(BUN n, BAT *s, BAT *g, int nbats, BAT **bats, bool *asc,
+                 bool *nilslast)
+{
+       const char *err;
+       oid min, max;
+       BUN ngrp;
+       struct canditer ci;
+       QryCtx *qry_ctx = MT_thread_get_qry_ctx();
+       struct groupedfirstn_info *batinfo;
+
+       assert(nbats > 0);
+
+       if (n == 0 || BATcount(bats[0]) == 0) {
+               return BATdense(0, 0, 0);
+       }
+
+       if ((err = BATgroupaggrinit(bats[0], g, NULL /* e */, s,
+                                   &min, &max, &ngrp, &ci)) != NULL) {
+               GDKerror("%s\n", err);
+               return NULL;
+       }
+
+       batinfo = GDKmalloc(nbats * sizeof(struct groupedfirstn_info));
+       if (batinfo == NULL)
+               return NULL;
+
+       /* all slots start out as nil, i.e. "empty" */
+       BAT *bn = BATconstant(0, TYPE_oid, &oid_nil, ngrp * n, TRANSIENT);
+       if (bn == NULL) {
+               GDKfree(batinfo);
+               return NULL;
+       }
+
+       for (int i = 0; i < nbats; i++) {
+               batinfo[i] = (struct groupedfirstn_info) {
+                       .bi1 = bat_iterator(bats[i]),
+                       .bi2 = bat_iterator(bats[i]),
+                       .asc = asc ? asc[i] : false,
+                       .nilslast = nilslast ? nilslast[i] : true,
+                       .cmp = ATOMcompare(bats[i]->ttype),
+                       .hseq = bats[i]->hseqbase,
+                       .nil = ATOMnilptr(bats[i]->ttype),
+               };
+       }
+
+       /* For each group we maintain a "heap" of N values inside the
+        * return bat BN.  The heap for group GRP is located in BN at
+        * BUN [grp*N..(grp+1)*N).  The first value in this heap is the
+        * "largest" (assuming all ASC bits are set) value so far, i.e.
+        * the row that ranks last among the rows kept for the group.
+        * An empty (nil) slot counts as ranking after every row, so
+        * empty slots rise to the top until the heap is full. */
+       oid *oids = Tloc(bn, 0);
+       TIMEOUT_LOOP(ci.ncand, qry_ctx) {
+               oid o = canditer_next(&ci);
+               oid grp = g ? BUNtoid(g, o - g->hseqbase) : 0;
+               BUN goff = grp * n;
+
+               /* If the heap is full (root slot in use), the incoming
+                * row only enters when it ranks strictly before the
+                * current last of the first-n; on a tie we keep what
+                * we have.  If the root is empty there is still room
+                * and the row is always accepted. */
+               if (!is_oid_nil(oids[goff]) &&
+                   groupedfirstn_cmp(batinfo, nbats, o, oids[goff]) >= 0)
+                       continue;
+
+               /* replace the root and sift the new row down */
+               oids[goff] = o;
+               BUN pos = 0;
+               BUN childpos = 1;
+               while (childpos < n) {
+                       /* find most extreme child: an empty slot beats
+                        * any row, otherwise the row that ranks last;
+                        * on a tie the left child is kept */
+                       if (childpos + 1 < n &&
+                           !is_oid_nil(oids[goff + childpos]) &&
+                           (is_oid_nil(oids[goff + childpos + 1]) ||
+                            groupedfirstn_cmp(batinfo, nbats,
+                                              oids[goff + childpos],
+                                              oids[goff + childpos + 1]) < 0))
+                               childpos++;
+                       /* compare parent with most extreme child; an
+                        * empty child always moves up */
+                       if (!is_oid_nil(oids[goff + childpos]) &&
+                           groupedfirstn_cmp(batinfo, nbats,
+                                             oids[goff + pos],
+                                             oids[goff + childpos]) >= 0) {
+                               /* already correctly ordered */
+                               break;
+                       }
+                       oid tmp = oids[goff + pos];
+                       oids[goff + pos] = oids[goff + childpos];
+                       oids[goff + childpos] = tmp;
+                       pos = childpos;
+                       childpos = (pos << 1) + 1;
+               }
+       }
+       for (int i = 0; i < nbats; i++) {
+               bat_iterator_end(&batinfo[i].bi1);
+               bat_iterator_end(&batinfo[i].bi2);
+       }
+       GDKfree(batinfo);
+       TIMEOUT_CHECK(qry_ctx, GOTO_LABEL_TIMEOUT_HANDLER(bailout, qry_ctx));
+       return bn;
+
+  bailout:
+       BBPreclaim(bn);
+       return NULL;
+}
diff --git a/monetdb5/modules/kernel/algebra.c 
b/monetdb5/modules/kernel/algebra.c
--- a/monetdb5/modules/kernel/algebra.c
+++ b/monetdb5/modules/kernel/algebra.c
@@ -992,6 +992,80 @@ ALGfirstn(Client cntxt, MalBlkPtr mb, Ma
 }
 
 static str
+ALGgroupedfirstn(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci)
+{
+       bat *ret;
+       bat sid, gid;
+       BAT *s = NULL, *g = NULL;
+       BAT *bn = NULL;
+       lng n;
+
+       (void) cntxt;
+       (void) mb;
+
+       n = *getArgReference_lng(stk, pci, 1);
+       if (n < 0)
+               throw(MAL, "algebra.groupedfirstn", ILLEGAL_ARGUMENT);
+       ret = getArgReference_bat(stk, pci, 0);
+       sid = *getArgReference_bat(stk, pci, 2);
+       gid = *getArgReference_bat(stk, pci, 3);
+       int nbats = pci->argc - 4;
+       if (nbats % 3 != 0)
+               throw(MAL, "algebra.groupedfirstn", ILLEGAL_ARGUMENT);
+       nbats /= 3;
+       BAT **bats = GDKmalloc(nbats * sizeof(BAT *));
+       bool *ascs = GDKmalloc(nbats * sizeof(bool));
+       bool *nlss = GDKmalloc(nbats * sizeof(bool));
+       if (bats == NULL || ascs == NULL || nlss == NULL) {
+               GDKfree(bats);
+               GDKfree(ascs);
+               GDKfree(nlss);
+               throw(MAL, "algebra.groupedfirstn", MAL_MALLOC_FAIL);
+       }
+       if (!is_bat_nil(sid) && (s = BATdescriptor(sid)) == NULL) {
+               GDKfree(bats);
+               GDKfree(ascs);
+               GDKfree(nlss);
+               throw(MAL, "algebra.groupedfirstn", SQLSTATE(HY002) 
RUNTIME_OBJECT_MISSING);
+       }
+       if (!is_bat_nil(gid) && (g = BATdescriptor(gid)) == NULL) {
+               BBPreclaim(s);
+               GDKfree(bats);
+               GDKfree(ascs);
+               GDKfree(nlss);
+               throw(MAL, "algebra.groupedfirstn", SQLSTATE(HY002) 
RUNTIME_OBJECT_MISSING);
+       }
+       for (int i = 0; i < nbats; i++) {
+               bats[i] = BATdescriptor(*getArgReference_bat(stk, pci, i * 3 + 
4));
+               if (bats[i] == NULL) {
+                       while (i > 0)
_______________________________________________
checkin-list mailing list -- [email protected]
To unsubscribe send an email to [email protected]

Reply via email to