Changeset: b97b03897b7a for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/b97b03897b7a
Modified Files:
        monetdb5/extras/rapi/rapi.c
        sql/backends/monet5/UDF/capi/capi.c
        sql/backends/monet5/UDF/pyapi3/pyapi3.c
Branch: Jul2021
Log Message:

Fix refcount on output bats on error


diffs (225 lines):

diff --git a/monetdb5/extras/rapi/rapi.c b/monetdb5/extras/rapi/rapi.c
--- a/monetdb5/extras/rapi/rapi.c
+++ b/monetdb5/extras/rapi/rapi.c
@@ -546,35 +546,54 @@ static char *RAPIinstalladdons(void) {
 static str
 empty_return(MalBlkPtr mb, MalStkPtr stk, InstrPtr pci, size_t retcols, oid 
seqbase)
 {
+       str msg = MAL_SUCCEED;
+       void **res = GDKzalloc(retcols * sizeof(void*));
+
+       if (!res) {
+               msg = createException(MAL, "pyapi3.eval", SQLSTATE(HY013) 
MAL_MALLOC_FAIL);
+               goto bailout;
+       }
+
        for (size_t i = 0; i < retcols; i++) {
                if (isaBatType(getArgType(mb, pci, i))) {
                        BAT *b = COLnew(seqbase, getBatType(getArgType(mb, pci, 
i)), 0, TRANSIENT);
                        if (!b) {
-                               for (size_t j = 0; j < i; j++) {
-                                       if (isaBatType(getArgType(mb, pci, j)))
-                                               
BBPunfix(*getArgReference_bat(stk, pci, j));
-                                       else
-                                               
VALclear(&stk->stk[pci->argv[j]]);
-                               }
-                               return createException(MAL, "rapi.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                               msg = createException(MAL, "pyapi3.eval", 
GDK_EXCEPTION);
+                               goto bailout;
                        }
-                       *getArgReference_bat(stk, pci, i) = b->batCacheid;
-                       BBPkeepref(b->batCacheid);
+                       ((BAT**)res)[i] = b;
                } else { // single value return, only for non-grouped 
aggregations
                        // return NULL to conform to SQL aggregates
                        int tpe = getArgType(mb, pci, i);
                        if (!VALinit(&stk->stk[pci->argv[i]], tpe, 
ATOMnilptr(tpe))) {
-                               for (size_t j = 0; j < i; j++) {
-                                       if (isaBatType(getArgType(mb, pci, j)))
-                                               
BBPunfix(*getArgReference_bat(stk, pci, j));
-                                       else
-                                               
VALclear(&stk->stk[pci->argv[j]]);
+                               msg = createException(MAL, "pyapi3.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                               goto bailout;
+                       }
+                       ((ValPtr*)res)[i] = &stk->stk[pci->argv[i]];
+               }
+       }
+
+bailout:
+       if (res) {
+               for (size_t i = 0; i < retcols; i++) {
+                       if (isaBatType(getArgType(mb, pci, i))) {
+                               BAT *b = ((BAT**)res)[i];
+
+                               if (b && msg) {
+                                       BBPreclaim(b);
+                               } else if (b) {
+                                       BBPkeepref(*getArgReference_bat(stk, 
pci, i) = b->batCacheid);
                                }
-                               return createException(MAL, "rapi.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                       } else if (msg) {
+                               ValPtr pt = ((ValPtr*)res)[i];
+
+                               if (pt)
+                                       VALclear(pt);
                        }
                }
+               GDKfree(res);
        }
-       return MAL_SUCCEED;
+       return msg;
 }
 
 static str RAPIeval(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci, 
bit grouped) {
diff --git a/sql/backends/monet5/UDF/capi/capi.c 
b/sql/backends/monet5/UDF/capi/capi.c
--- a/sql/backends/monet5/UDF/capi/capi.c
+++ b/sql/backends/monet5/UDF/capi/capi.c
@@ -382,35 +382,54 @@ static char valid_path_characters[] = "a
 static str
 empty_return(MalBlkPtr mb, MalStkPtr stk, InstrPtr pci, size_t retcols, oid 
seqbase)
 {
+       str msg = MAL_SUCCEED;
+       void **res = GDKzalloc(retcols * sizeof(void*));
+
+       if (!res) {
+               msg = createException(MAL, "pyapi3.eval", SQLSTATE(HY013) 
MAL_MALLOC_FAIL);
+               goto bailout;
+       }
+
        for (size_t i = 0; i < retcols; i++) {
                if (isaBatType(getArgType(mb, pci, i))) {
                        BAT *b = COLnew(seqbase, getBatType(getArgType(mb, pci, 
i)), 0, TRANSIENT);
                        if (!b) {
-                               for (size_t j = 0; j < i; j++) {
-                                       if (isaBatType(getArgType(mb, pci, j)))
-                                               
BBPunfix(*getArgReference_bat(stk, pci, j));
-                                       else
-                                               
VALclear(&stk->stk[pci->argv[j]]);
-                               }
-                               return createException(MAL, "cudf.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                               msg = createException(MAL, "pyapi3.eval", 
GDK_EXCEPTION);
+                               goto bailout;
                        }
-                       *getArgReference_bat(stk, pci, i) = b->batCacheid;
-                       BBPkeepref(b->batCacheid);
+                       ((BAT**)res)[i] = b;
                } else { // single value return, only for non-grouped 
aggregations
                        // return NULL to conform to SQL aggregates
                        int tpe = getArgType(mb, pci, i);
                        if (!VALinit(&stk->stk[pci->argv[i]], tpe, 
ATOMnilptr(tpe))) {
-                               for (size_t j = 0; j < i; j++) {
-                                       if (isaBatType(getArgType(mb, pci, j)))
-                                               
BBPunfix(*getArgReference_bat(stk, pci, j));
-                                       else
-                                               
VALclear(&stk->stk[pci->argv[j]]);
+                               msg = createException(MAL, "pyapi3.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                               goto bailout;
+                       }
+                       ((ValPtr*)res)[i] = &stk->stk[pci->argv[i]];
+               }
+       }
+
+bailout:
+       if (res) {
+               for (size_t i = 0; i < retcols; i++) {
+                       if (isaBatType(getArgType(mb, pci, i))) {
+                               BAT *b = ((BAT**)res)[i];
+
+                               if (b && msg) {
+                                       BBPreclaim(b);
+                               } else if (b) {
+                                       BBPkeepref(*getArgReference_bat(stk, 
pci, i) = b->batCacheid);
                                }
-                               return createException(MAL, "cudf.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                       } else if (msg) {
+                               ValPtr pt = ((ValPtr*)res)[i];
+
+                               if (pt)
+                                       VALclear(pt);
                        }
                }
+               GDKfree(res);
        }
-       return MAL_SUCCEED;
+       return msg;
 }
 
 static str CUDFeval(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci,
diff --git a/sql/backends/monet5/UDF/pyapi3/pyapi3.c 
b/sql/backends/monet5/UDF/pyapi3/pyapi3.c
--- a/sql/backends/monet5/UDF/pyapi3/pyapi3.c
+++ b/sql/backends/monet5/UDF/pyapi3/pyapi3.c
@@ -1636,37 +1636,56 @@ wrapup:
 }
 
 static str CreateEmptyReturn(MalBlkPtr mb, MalStkPtr stk, InstrPtr pci,
-                                                         size_t retcols, oid 
seqbase)
+                                                        size_t retcols, oid 
seqbase)
 {
+       str msg = MAL_SUCCEED;
+       void **res = GDKzalloc(retcols * sizeof(void*));
+
+       if (!res) {
+               msg = createException(MAL, "pyapi3.eval", SQLSTATE(HY013) 
MAL_MALLOC_FAIL);
+               goto bailout;
+       }
+
        for (size_t i = 0; i < retcols; i++) {
                if (isaBatType(getArgType(mb, pci, i))) {
                        BAT *b = COLnew(seqbase, getBatType(getArgType(mb, pci, 
i)), 0, TRANSIENT);
                        if (!b) {
-                               for (size_t j = 0; j < i; j++) {
-                                       if (isaBatType(getArgType(mb, pci, j)))
-                                               
BBPunfix(*getArgReference_bat(stk, pci, j));
-                                       else
-                                               
VALclear(&stk->stk[pci->argv[j]]);
-                               }
-                               return createException(MAL, "pyapi3.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                               msg = createException(MAL, "pyapi3.eval", 
GDK_EXCEPTION);
+                               goto bailout;
                        }
-                       *getArgReference_bat(stk, pci, i) = b->batCacheid;
-                       BBPkeepref(b->batCacheid);
+                       ((BAT**)res)[i] = b;
                } else { // single value return, only for non-grouped 
aggregations
                        // return NULL to conform to SQL aggregates
                        int tpe = getArgType(mb, pci, i);
                        if (!VALinit(&stk->stk[pci->argv[i]], tpe, 
ATOMnilptr(tpe))) {
-                               for (size_t j = 0; j < i; j++) {
-                                       if (isaBatType(getArgType(mb, pci, j)))
-                                               
BBPunfix(*getArgReference_bat(stk, pci, j));
-                                       else
-                                               
VALclear(&stk->stk[pci->argv[j]]);
+                               msg = createException(MAL, "pyapi3.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                               goto bailout;
+                       }
+                       ((ValPtr*)res)[i] = &stk->stk[pci->argv[i]];
+               }
+       }
+
+bailout:
+       if (res) {
+               for (size_t i = 0; i < retcols; i++) {
+                       if (isaBatType(getArgType(mb, pci, i))) {
+                               BAT *b = ((BAT**)res)[i];
+
+                               if (b && msg) {
+                                       BBPreclaim(b);
+                               } else if (b) {
+                                       BBPkeepref(*getArgReference_bat(stk, 
pci, i) = b->batCacheid);
                                }
-                               return createException(MAL, "pyapi3.eval", 
SQLSTATE(HY013) MAL_MALLOC_FAIL);
+                       } else if (msg) {
+                               ValPtr pt = ((ValPtr*)res)[i];
+
+                               if (pt)
+                                       VALclear(pt);
                        }
                }
+               GDKfree(res);
        }
-       return MAL_SUCCEED;
+       return msg;
 }
 
 #include "mel.h"
_______________________________________________
checkin-list mailing list
[email protected]
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to