Changeset: c5b67bdd8bcd for MonetDB
URL: http://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=c5b67bdd8bcd
Modified Files:
        sql/backends/monet5/UDF/capi/capi.c
Branch: jitudf
Log Message:

When an access to an mprotect-ed page triggers the SIGSEGV/SIGBUS handler, check
whether the faulting address actually falls in the area we meant to protect.


diffs (192 lines):

diff --git a/sql/backends/monet5/UDF/capi/capi.c b/sql/backends/monet5/UDF/capi/capi.c
--- a/sql/backends/monet5/UDF/capi/capi.c
+++ b/sql/backends/monet5/UDF/capi/capi.c
@@ -27,8 +27,24 @@ typedef struct _allocated_region {
        struct _allocated_region *next;
 } allocated_region;
 
-static __thread allocated_region *allocated_regions;
-static __thread jmp_buf jump_buffer;
+struct _mprotected_region;
+typedef struct _mprotected_region {
+       void *addr;
+       size_t len;
+
+       void* actual_addr;
+       size_t actual_len;
+
+       struct _mprotected_region *next;
+} mprotected_region;
+
+static char *mprotect_region(void *addr, size_t len, int flags,
+                                                        mprotected_region 
**regions);
+static char *clear_mprotect(void *addr, size_t len);
+
+static allocated_region *allocated_regions[THREADS];
+static mprotected_region **actual_mprotected_regions[THREADS];
+static jmp_buf jump_buffer[THREADS];
 
 typedef char *(*jitted_function)(void **inputs, void **outputs,
                                                                 
malloc_function_ptr malloc);
@@ -81,19 +97,49 @@ static bool WriteTextToFile(FILE *f, con
 
 static void handler(int sig, siginfo_t *si, void *unused)
 {
+       int actually_protected_area = false;
+       mprotected_region* found_region = NULL;
+       int tid = THRgettid();
+
        (void)sig;
-       (void)si;
        (void)unused;
        // we caught a segfault or bus error
-       longjmp(jump_buffer, 1);
-}
+       // this can be either because 
+       // (1) the function accessed a protected piece of memory
+       // (2) the function caused a segfault by e.g. dereferencing a NULL 
pointer
+       // in the first case, this *might* be a valid memory access
+       // this is because we are required to align our mprotects on page 
boundaries
+       // thus sometimes we mprotect a page where only part of the page
+       // should actually be protected. Thus for this case we check if the 
access
+       // was actually an error
 
-struct _mprotected_region;
-typedef struct _mprotected_region {
-       void *addr;
-       size_t len;
-       struct _mprotected_region *next;
-} mprotected_region;
+       if (actual_mprotected_regions[tid]) {
+               mprotected_region* region = *actual_mprotected_regions[tid];
+               while(region) {
+                       if (si->si_addr >= region->addr && (char*) si->si_addr 
<= (char*) region->addr + region->len) {
+                               // the address belongs to this mprotected region
+                               found_region = region;
+                               if (si->si_addr >= region->actual_addr && 
+                                       (char*) si->si_addr <= (char*) 
region->actual_addr + region->actual_len) {
+                                       // and the address is actually supposed 
to be protected
+                                       actually_protected_area = true;
+                                       break;
+                               }
+                       }
+                       region = region->next;
+               }
+       }
+       if (found_region && !actually_protected_area) {
+               // this is NOT an actually protected area
+               // thus the segfault/bus error is invalid
+               // the nasty part here is that we have to unprotect the entire 
page now
+               // thus opening us up to future modifications of the data
+               clear_mprotect(found_region->addr, found_region->len);
+               found_region->addr = NULL;
+       } else {
+               longjmp(jump_buffer[tid], 1);
+       }
+}
 
 static char *mprotect_region(void *addr, size_t len, int flags,
                                                         mprotected_region 
**regions)
@@ -101,6 +147,8 @@ static char *mprotect_region(void *addr,
        mprotected_region *region;
        int pagesize;
        void *page_begin;
+       void* actual_addr = addr;
+       size_t actual_len = len;
        if (len == 0)
                return NULL;
        // check if the region is page-aligned
@@ -126,12 +174,16 @@ static char *mprotect_region(void *addr,
        region->addr = addr;
        region->len = len;
        region->next = *regions;
+       region->actual_addr = actual_addr;
+       region->actual_len = actual_len;
        *regions = region;
        return NULL;
 }
 
 static char *clear_mprotect(void *addr, size_t len)
 {
+       if (!addr) return NULL;
+
        if (mprotect(addr, len, PROT_READ | PROT_WRITE) < 0) {
                return strerror(errno);
        }
@@ -156,7 +208,7 @@ static void *jump_GDK_malloc(size_t size
 {
        void *ptr = GDKmalloc(size);
        if (!ptr) {
-               longjmp(jump_buffer, 2);
+               longjmp(jump_buffer[THRgettid()], 2);
        }
        return ptr;
 }
@@ -165,9 +217,10 @@ static void *wrapped_GDK_malloc(size_t s
 {
        allocated_region *region;
        void *ptr = jump_GDK_malloc(size + sizeof(allocated_region));
+       int tid = THRgettid();
        region = (allocated_region *)ptr;
-       region->next = allocated_regions;
-       allocated_regions = region;
+       region->next = allocated_regions[tid];
+       allocated_regions[tid] = region;
 
        return (char*)ptr + sizeof(allocated_region);
 }
@@ -335,10 +388,12 @@ static str CUDFeval(Client cntxt, MalBlk
        BUN expression_hash = 0, funcname_hash = 0;
        cached_functions* cached_function;
        char* function_parameters = NULL;
+       int tid = THRgettid();
 
        (void)cntxt;
 
-       allocated_regions = NULL;
+       actual_mprotected_regions[tid] = &regions;
+       allocated_regions[tid] = NULL;
 
        // we need to be able to catch segfaults and bus errors
        // so we can work with mprotect to prevent UDFs from changing
@@ -944,7 +999,7 @@ static str CUDFeval(Client cntxt, MalBlk
        // set up a longjmp point
        // this longjmp point is used for some error handling in the C function
        // such as failed mallocs
-       ret = setjmp(jump_buffer);
+       ret = setjmp(jump_buffer[tid]);
        if (ret < 0) {
                // error value
                msg = createException(MAL, "cudf.eval", "Failed setjmp: %s",
@@ -1004,7 +1059,15 @@ static str CUDFeval(Client cntxt, MalBlk
                goto wrapup;
        }
 
-       // FIXME: deal with strings
+       actual_mprotected_regions[tid] = NULL;
+       // clear any mprotected regions
+       while (regions) {
+               mprotected_region *next = regions->next;
+               clear_mprotect(regions->addr, regions->len);
+               GDKfree(regions);
+               regions = next;
+       }
+
        // FIXME: deal with SQL types
 
        // create the output bats from the returned results
@@ -1140,10 +1203,10 @@ wrapup:
                GDKfree(regions);
                regions = next;
        }
-       while (allocated_regions) {
-               allocated_region *next = allocated_regions->next;
-               GDKfree(allocated_regions);
-               allocated_regions = next;
+       while (allocated_regions[tid]) {
+               allocated_region *next = allocated_regions[tid]->next;
+               GDKfree(allocated_regions[tid]);
+               allocated_regions[tid] = next;
        }
        // block segfaults and bus errors again after we exit
        (void)pthread_sigmask(SIG_BLOCK, &signal_set, NULL);
_______________________________________________
checkin-list mailing list
[email protected]
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to