Changeset: 80b784215c1e for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/80b784215c1e
Added Files:
        sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test
Modified Files:
        sql/backends/monet5/UDF/capi/capi.c
        sql/test/BugTracker-2022/Tests/All
        sql/test/BugTracker-2022/Tests/SingleServer
Branch: default
Log Message:

Added test and fix for bug #7285: use the right BAT (the extents BAT) for
aggr_group.count. At the same time, ignore any extra input arguments that
follow the groups BAT, including the extents BAT when it is present.


diffs (195 lines):

diff --git a/sql/backends/monet5/UDF/capi/capi.c 
b/sql/backends/monet5/UDF/capi/capi.c
--- a/sql/backends/monet5/UDF/capi/capi.c
+++ b/sql/backends/monet5/UDF/capi/capi.c
@@ -454,11 +454,11 @@ static str CUDFeval(Client cntxt, MalBlk
        str *output_names = NULL;
        char *msg = MAL_SUCCEED;
        node *argnode;
-       int seengrp = FALSE;
+       int seengrp = 0;
        FILE *f = NULL;
        void *handle = NULL;
        jitted_function func = NULL;
-       int ret;
+       int ret, limit_argc = 0;
 
        FILE *compiler = NULL;
        int compiler_return_code;
@@ -579,7 +579,7 @@ static str CUDFeval(Client cntxt, MalBlk
        }
        // the first unknown argument is the group, we don't really care for the
        // rest.
-       for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc; i++) {
+       for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc && !seengrp; 
i++) {
                if (args[i] == NULL) {
                        if (!seengrp && grouped) {
                                args[i] = GDKstrdup("aggr_group");
@@ -587,7 +587,7 @@ static str CUDFeval(Client cntxt, MalBlk
                                        msg = createException(MAL, "cudf.eval", 
MAL_MALLOC_FAIL);
                                        goto wrapup;
                                }
-                               seengrp = TRUE;
+                               seengrp = i; /* Don't be interested in the 
extents BAT */
                        } else {
                                snprintf(argbuf, sizeof(argbuf), "arg%zu", i - 
pci->retc - 1);
                                args[i] = GDKstrdup(argbuf);
@@ -598,12 +598,14 @@ static str CUDFeval(Client cntxt, MalBlk
                        }
                }
        }
+       // the first index where input arguments are not relevant for the C UDF
+       limit_argc = i;
        // non-grouped aggregates don't have the group list
        // to allow users to write code for both grouped and non-grouped 
aggregates
        // we create an "aggr_group" BAT for non-grouped aggregates
        non_grouped_aggregate = grouped && !seengrp;
 
-       input_count = pci->argc - (pci->retc + ARG_OFFSET);
+       input_count = limit_argc - (pci->retc + ARG_OFFSET);
        output_count = pci->retc;
 
        // begin the compilation phase
@@ -613,7 +615,7 @@ static str CUDFeval(Client cntxt, MalBlk
        funcname_hash = strHash(funcname);
        funcname_hash = funcname_hash % FUNCTION_CACHE_SIZE;
        j = 0;
-       for (i = 0; i < (size_t)pci->argc; i++) {
+       for (i = 0; i < (size_t)limit_argc; i++) {
                if (args[i]) {
                        j += strlen(args[i]);
                }
@@ -644,7 +646,7 @@ static str CUDFeval(Client cntxt, MalBlk
                }
        }
        j = input_count + output_count;
-       for (i = 0; i < (size_t)pci->argc; i++) {
+       for (i = 0; i < (size_t)limit_argc; i++) {
                if (args[i]) {
                        size_t len = strlen(args[i]);
                        memcpy(function_parameters + j, args[i], len);
@@ -826,7 +828,7 @@ static str CUDFeval(Client cntxt, MalBlk
                // input/output
                // of the function
                // first convert the input
-               for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc; i++) {
+               for (i = pci->retc + ARG_OFFSET; i < (size_t)limit_argc; i++) {
                        bat_type = !isaBatType(getArgType(mb, pci, i))
                                                           ? getArgType(mb, 
pci, i)
                                                           : 
getBatType(getArgType(mb, pci, i));
@@ -988,7 +990,7 @@ static str CUDFeval(Client cntxt, MalBlk
        }
        // create the inputs
        argnode = sqlfun ? sqlfun->ops->h : NULL;
-       for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc; i++) {
+       for (i = pci->retc + ARG_OFFSET; i < (size_t)limit_argc; i++) {
                index = i - (pci->retc + ARG_OFFSET);
                bat_type = getArgType(mb, pci, i);
                if (!isaBatType(bat_type)) {
@@ -1040,6 +1042,17 @@ static str CUDFeval(Client cntxt, MalBlk
                        GENERATE_BAT_INPUT(input_bats[index], int);
                } else if (bat_type == TYPE_oid) {
                        GENERATE_BAT_INPUT(input_bats[index], oid);
+                       // Hack for groups BAT, the count should reflect on the 
number of groups and not the number
+                       // of rows, so use extents BAT
+                       if (i == (size_t)seengrp) {
+                               struct cudf_data_struct_oid *t = inputs[index];
+                               BAT *ex = 
BBPquickdesc(*getArgReference_bat(stk, pci, i + 1));
+                               if (!ex) {
+                                       msg = createException(MAL, "cudf.eval", 
RUNTIME_OBJECT_MISSING);
+                                       goto wrapup;
+                               }
+                               t->count = BATcount(ex);
+                       }
                } else if (bat_type == TYPE_lng) {
                        GENERATE_BAT_INPUT(input_bats[index], lng);
                } else if (bat_type == TYPE_flt) {
@@ -1602,7 +1615,7 @@ wrapup:
        }
        // argument names (input)
        if (args) {
-               for (i = 0; i < (size_t)pci->argc; i++) {
+               for (i = 0; i < (size_t)limit_argc; i++) {
                        if (args[i]) {
                                GDKfree(args[i]);
                        }
diff --git a/sql/test/BugTracker-2022/Tests/All 
b/sql/test/BugTracker-2022/Tests/All
--- a/sql/test/BugTracker-2022/Tests/All
+++ b/sql/test/BugTracker-2022/Tests/All
@@ -8,3 +8,4 @@ pkey-restart.Bug-7263
 delete-update.Bug-7267
 having-clauses.Bug-7278
 dump-table-data.Bug-7282
+NOT_WIN32?cudf-grouped-aggr.Bug-7285
diff --git a/sql/test/BugTracker-2022/Tests/SingleServer 
b/sql/test/BugTracker-2022/Tests/SingleServer
--- a/sql/test/BugTracker-2022/Tests/SingleServer
+++ b/sql/test/BugTracker-2022/Tests/SingleServer
@@ -1,1 +1,3 @@
 --set embedded_py=3
+--set embedded_c=true
+--set capi_cc='cc -std=c99'
diff --git a/sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test 
b/sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test
new file mode 100644
--- /dev/null
+++ b/sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test
@@ -0,0 +1,61 @@
+statement ok
+START TRANSACTION
+
+statement ok
+CREATE AGGREGATE jit_sum(input INTEGER) RETURNS BIGINT LANGUAGE C {
+    // initialize one aggregate per group
+    result->initialize(result, aggr_group.count);
+    // zero initialize the sums
+    memset(result->data, 0, result->count * sizeof(result->null_value));
+    // gather the sums for each of the groups
+    for(size_t i = 0; i < input.count; i++) {
+        result->data[aggr_group.data[i]] += input.data[i];
+    }
+}
+
+statement ok
+CREATE TABLE tab(x INTEGER, y INTEGER)
+
+statement ok rowcount 7
+insert into tab values (1,1),(2,2),(3,3),(1,10),(2,50),(3,6),(1,100000)
+
+query I rowsort
+SELECT jit_sum(y) FROM tab
+----
+100072
+
+query II rowsort
+SELECT sum(y), jit_sum(y) FROM tab
+----
+100072
+100072
+
+query I rowsort
+SELECT sum(y) FROM tab GROUP BY x
+----
+100011
+52
+9
+
+query I rowsort
+SELECT jit_sum(y) FROM tab GROUP BY x
+----
+100011
+52
+9
+
+query III rowsort
+SELECT x, sum(y), jit_sum(y) FROM tab GROUP BY x
+----
+1
+100011
+100011
+2
+52
+52
+3
+9
+9
+
+statement ok
+ROLLBACK
_______________________________________________
checkin-list mailing list -- [email protected]
To unsubscribe send an email to [email protected]

Reply via email to