Changeset: 80b784215c1e for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/80b784215c1e
Added Files:
sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test
Modified Files:
sql/backends/monet5/UDF/capi/capi.c
sql/test/BugTracker-2022/Tests/All
sql/test/BugTracker-2022/Tests/SingleServer
Branch: default
Log Message:
Added test and fix for bug #7285: use the right BAT for aggr_group.count. At the
same time, for grouped aggregates, ignore the extra input arguments that follow the group argument (including the extents BAT).
diffs (195 lines):
diff --git a/sql/backends/monet5/UDF/capi/capi.c
b/sql/backends/monet5/UDF/capi/capi.c
--- a/sql/backends/monet5/UDF/capi/capi.c
+++ b/sql/backends/monet5/UDF/capi/capi.c
@@ -454,11 +454,11 @@ static str CUDFeval(Client cntxt, MalBlk
str *output_names = NULL;
char *msg = MAL_SUCCEED;
node *argnode;
- int seengrp = FALSE;
+ int seengrp = 0;
FILE *f = NULL;
void *handle = NULL;
jitted_function func = NULL;
- int ret;
+ int ret, limit_argc = 0;
FILE *compiler = NULL;
int compiler_return_code;
@@ -579,7 +579,7 @@ static str CUDFeval(Client cntxt, MalBlk
}
// the first unknown argument is the group, we don't really care for the
// rest.
- for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc; i++) {
+ for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc && !seengrp;
i++) {
if (args[i] == NULL) {
if (!seengrp && grouped) {
args[i] = GDKstrdup("aggr_group");
@@ -587,7 +587,7 @@ static str CUDFeval(Client cntxt, MalBlk
msg = createException(MAL, "cudf.eval",
MAL_MALLOC_FAIL);
goto wrapup;
}
- seengrp = TRUE;
+ seengrp = i; /* Don't be interested in the
extents BAT */
} else {
snprintf(argbuf, sizeof(argbuf), "arg%zu", i -
pci->retc - 1);
args[i] = GDKstrdup(argbuf);
@@ -598,12 +598,14 @@ static str CUDFeval(Client cntxt, MalBlk
}
}
}
+ // the first index where input arguments are not relevant for the C UDF
+ limit_argc = i;
// non-grouped aggregates don't have the group list
// to allow users to write code for both grouped and non-grouped
aggregates
// we create an "aggr_group" BAT for non-grouped aggregates
non_grouped_aggregate = grouped && !seengrp;
- input_count = pci->argc - (pci->retc + ARG_OFFSET);
+ input_count = limit_argc - (pci->retc + ARG_OFFSET);
output_count = pci->retc;
// begin the compilation phase
@@ -613,7 +615,7 @@ static str CUDFeval(Client cntxt, MalBlk
funcname_hash = strHash(funcname);
funcname_hash = funcname_hash % FUNCTION_CACHE_SIZE;
j = 0;
- for (i = 0; i < (size_t)pci->argc; i++) {
+ for (i = 0; i < (size_t)limit_argc; i++) {
if (args[i]) {
j += strlen(args[i]);
}
@@ -644,7 +646,7 @@ static str CUDFeval(Client cntxt, MalBlk
}
}
j = input_count + output_count;
- for (i = 0; i < (size_t)pci->argc; i++) {
+ for (i = 0; i < (size_t)limit_argc; i++) {
if (args[i]) {
size_t len = strlen(args[i]);
memcpy(function_parameters + j, args[i], len);
@@ -826,7 +828,7 @@ static str CUDFeval(Client cntxt, MalBlk
// input/output
// of the function
// first convert the input
- for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc; i++) {
+ for (i = pci->retc + ARG_OFFSET; i < (size_t)limit_argc; i++) {
bat_type = !isaBatType(getArgType(mb, pci, i))
? getArgType(mb,
pci, i)
:
getBatType(getArgType(mb, pci, i));
@@ -988,7 +990,7 @@ static str CUDFeval(Client cntxt, MalBlk
}
// create the inputs
argnode = sqlfun ? sqlfun->ops->h : NULL;
- for (i = pci->retc + ARG_OFFSET; i < (size_t)pci->argc; i++) {
+ for (i = pci->retc + ARG_OFFSET; i < (size_t)limit_argc; i++) {
index = i - (pci->retc + ARG_OFFSET);
bat_type = getArgType(mb, pci, i);
if (!isaBatType(bat_type)) {
@@ -1040,6 +1042,17 @@ static str CUDFeval(Client cntxt, MalBlk
GENERATE_BAT_INPUT(input_bats[index], int);
} else if (bat_type == TYPE_oid) {
GENERATE_BAT_INPUT(input_bats[index], oid);
+ // Hack for groups BAT, the count should reflect on the
number of groups and not the number
+ // of rows, so use extents BAT
+ if (i == (size_t)seengrp) {
+ struct cudf_data_struct_oid *t = inputs[index];
+ BAT *ex =
BBPquickdesc(*getArgReference_bat(stk, pci, i + 1));
+ if (!ex) {
+ msg = createException(MAL, "cudf.eval",
RUNTIME_OBJECT_MISSING);
+ goto wrapup;
+ }
+ t->count = BATcount(ex);
+ }
} else if (bat_type == TYPE_lng) {
GENERATE_BAT_INPUT(input_bats[index], lng);
} else if (bat_type == TYPE_flt) {
@@ -1602,7 +1615,7 @@ wrapup:
}
// argument names (input)
if (args) {
- for (i = 0; i < (size_t)pci->argc; i++) {
+ for (i = 0; i < (size_t)limit_argc; i++) {
if (args[i]) {
GDKfree(args[i]);
}
diff --git a/sql/test/BugTracker-2022/Tests/All
b/sql/test/BugTracker-2022/Tests/All
--- a/sql/test/BugTracker-2022/Tests/All
+++ b/sql/test/BugTracker-2022/Tests/All
@@ -8,3 +8,4 @@ pkey-restart.Bug-7263
delete-update.Bug-7267
having-clauses.Bug-7278
dump-table-data.Bug-7282
+NOT_WIN32?cudf-grouped-aggr.Bug-7285
diff --git a/sql/test/BugTracker-2022/Tests/SingleServer
b/sql/test/BugTracker-2022/Tests/SingleServer
--- a/sql/test/BugTracker-2022/Tests/SingleServer
+++ b/sql/test/BugTracker-2022/Tests/SingleServer
@@ -1,1 +1,3 @@
--set embedded_py=3
+--set embedded_c=true
+--set capi_cc='cc -std=c99'
diff --git a/sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test
b/sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test
new file mode 100644
--- /dev/null
+++ b/sql/test/BugTracker-2022/Tests/cudf-grouped-aggr.Bug-7285.test
@@ -0,0 +1,61 @@
+statement ok
+START TRANSACTION
+
+statement ok
+CREATE AGGREGATE jit_sum(input INTEGER) RETURNS BIGINT LANGUAGE C {
+ // initialize one aggregate per group
+ result->initialize(result, aggr_group.count);
+ // zero initialize the sums
+ memset(result->data, 0, result->count * sizeof(result->null_value));
+ // gather the sums for each of the groups
+ for(size_t i = 0; i < input.count; i++) {
+ result->data[aggr_group.data[i]] += input.data[i];
+ }
+}
+
+statement ok
+CREATE TABLE tab(x INTEGER, y INTEGER)
+
+statement ok rowcount 7
+insert into tab values (1,1),(2,2),(3,3),(1,10),(2,50),(3,6),(1,100000)
+
+query I rowsort
+SELECT jit_sum(y) FROM tab
+----
+100072
+
+query II rowsort
+SELECT sum(y), jit_sum(y) FROM tab
+----
+100072
+100072
+
+query I rowsort
+SELECT sum(y) FROM tab GROUP BY x
+----
+100011
+52
+9
+
+query I rowsort
+SELECT jit_sum(y) FROM tab GROUP BY x
+----
+100011
+52
+9
+
+query III rowsort
+SELECT x, sum(y), jit_sum(y) FROM tab GROUP BY x
+----
+1
+100011
+100011
+2
+52
+52
+3
+9
+9
+
+statement ok
+ROLLBACK
_______________________________________________
checkin-list mailing list -- [email protected]
To unsubscribe send an email to [email protected]