Changeset: 0a98aecb2ea2 for MonetDB
URL: http://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=0a98aecb2ea2
Modified Files:
        sql/backends/monet5/UDF/capi/Tests/capi02.sql
        sql/backends/monet5/UDF/capi/capi.c
Branch: jitudf
Log Message:

Only mprotect regions after the signal handlers have been set.


diffs (186 lines):

diff --git a/sql/backends/monet5/UDF/capi/Tests/capi02.sql b/sql/backends/monet5/UDF/capi/Tests/capi02.sql
--- a/sql/backends/monet5/UDF/capi/Tests/capi02.sql
+++ b/sql/backends/monet5/UDF/capi/Tests/capi02.sql
@@ -45,16 +45,15 @@ DROP TABLE dates;
 CREATE FUNCTION capi02_randomize_time(d TIME) RETURNS TIME
 language C
 {
-       srand(1234);
        result->initialize(result, d.count);
        for(size_t i = 0; i < result->count; i++) {
                if (d.is_null(d.data[i])) {
                        result->data[i] = result->null_value;
                } else {
-                       result->data[i].hours = rand() % 24;
-                       result->data[i].minutes = rand() % 60;
-                       result->data[i].seconds = rand() % 60;
-                       result->data[i].ms = rand() % 1000;
+                       result->data[i].hours = (i + 1234) % 24;
+                       result->data[i].minutes = (i + 1234) % 60;
+                       result->data[i].seconds = (i + 1234) % 60;
+                       result->data[i].ms = (i + 1234) % 1000;
                }
        }
 };
@@ -77,7 +76,6 @@ DROP TABLE times;
 CREATE FUNCTION capi02_increment_timestamp(d TIMESTAMP) RETURNS TIMESTAMP
 language C
 {
-       srand(1234);
        result->initialize(result, d.count);
        for(size_t i = 0; i < result->count; i++) {
                if (d.is_null(d.data[i])) {
@@ -87,10 +85,10 @@ language C
                        result->data[i].date.month = d.data[i].date.month;
                        result->data[i].date.day = d.data[i].date.day;
 
-                       result->data[i].time.hours = rand() % 24;
-                       result->data[i].time.minutes = rand() % 60;
-                       result->data[i].time.seconds = rand() % 60;
-                       result->data[i].time.ms = rand() % 1000;
+                       result->data[i].time.hours = (i + 1234) % 24;
+                       result->data[i].time.minutes = (i + 1234) % 60;
+                       result->data[i].time.seconds = (i + 1234) % 60;
+                       result->data[i].time.ms = (i + 1234) % 1000;
                }
        }
 };
diff --git a/sql/backends/monet5/UDF/capi/capi.c b/sql/backends/monet5/UDF/capi/capi.c
--- a/sql/backends/monet5/UDF/capi/capi.c
+++ b/sql/backends/monet5/UDF/capi/capi.c
@@ -36,11 +36,12 @@ typedef struct _mprotected_region {
        void* actual_addr;
        size_t actual_len;
 
+       bool is_protected;
+
        struct _mprotected_region *next;
 } mprotected_region;
 
-static char *mprotect_region(void *addr, size_t len, int flags,
-                                                        mprotected_region 
**regions);
+static char *mprotect_region(void *addr, size_t len, mprotected_region 
**regions);
 static char *clear_mprotect(void *addr, size_t len);
 
 static allocated_region *allocated_regions[THREADS];
@@ -142,8 +143,7 @@ static void handler(int sig, siginfo_t *
        }
 }
 
-static char *mprotect_region(void *addr, size_t len, int flags,
-                                                        mprotected_region 
**regions)
+static char *mprotect_region(void *addr, size_t len, mprotected_region 
**regions)
 {
        mprotected_region *region;
        int pagesize;
@@ -170,15 +170,12 @@ static char *mprotect_region(void *addr,
        if (!region) {
                return MAL_MALLOC_FAIL;
        }
-       if (mprotect(addr, len, flags) < 0) {
-               GDKfree(region);
-               return strerror(errno);
-       }
        region->addr = addr;
        region->len = len;
        region->next = *regions;
        region->actual_addr = actual_addr;
        region->actual_len = actual_len;
+       region->is_protected = false;
        *regions = region;
        return NULL;
 }
@@ -298,7 +295,7 @@ GENERATE_BASE_HEADERS(cudf_data_blob, bl
                        bat_data->data = (tpe *)Tloc(b, 0);                     
               \
                        mprotect_retval = mprotect_region(                      
               \
                                bat_data->data, bat_data->count * 
sizeof(bat_data->null_value),    \
-                               PROT_READ, &regions);                           
                   \
+                               &regions);                                      
        \
                        if (mprotect_retval) {                                  
               \
                                msg = createException(MAL, "cudf.eval",         
                   \
                                                                          
"Failed to mprotect region: %s",             \
@@ -392,7 +389,7 @@ static str CUDFeval(Client cntxt, MalBlk
        void **outputs = NULL;
        size_t output_count = 0;
        BAT **input_bats = NULL;
-       mprotected_region *regions = NULL;
+       mprotected_region *regions = NULL, *region_iter = NULL;
 
        lng initial_output_count = -1;
 
@@ -976,7 +973,7 @@ static str CUDFeval(Client cntxt, MalBlk
                                assert(input_bats[index]->tvheap);
                                mprotect_retval = mprotect_region(
                                        input_bats[index]->tvheap->base,
-                                       input_bats[index]->tvheap->size, 
PROT_READ, &regions);
+                                       input_bats[index]->tvheap->size, 
&regions);
                                if (mprotect_retval) {
                                        msg = createException(MAL, "cudf.eval",
                                                                                
  "Failed to mprotect region: %s",
@@ -1063,7 +1060,7 @@ static str CUDFeval(Client cntxt, MalBlk
                                assert(input_bats[index]->tvheap);
                                mprotect_retval = mprotect_region(
                                        input_bats[index]->tvheap->base,
-                                       input_bats[index]->tvheap->size, 
PROT_READ, &regions);
+                                       input_bats[index]->tvheap->size, 
&regions);
                                if (mprotect_retval) {
                                        msg = createException(MAL, "cudf.eval",
                                                                                
  "Failed to mprotect region: %s",
@@ -1202,9 +1199,31 @@ static str CUDFeval(Client cntxt, MalBlk
                goto wrapup;
        }
 
+       // actually mprotect the regions now that the signal handlers are set
+       region_iter = regions;
+       while(region_iter) {
+               if (mprotect(region_iter->addr, region_iter->len, PROT_READ) < 
0) {
+                       goto wrapup;
+               }
+               region_iter->is_protected = true;
+               region_iter = region_iter->next;
+       }
+
        // call the actual jitted function
        msg = func(inputs, outputs, wrapped_GDK_malloc);
 
+       // clear any mprotected regions
+       while (regions) {
+               mprotected_region *next = regions->next;
+               if (regions->is_protected) {
+                       clear_mprotect(regions->addr, regions->len);    
+               }
+               GDKfree(regions);
+               regions = next;
+       }
+       
+       actual_mprotected_regions[tid] = NULL;
+
        // clear the signal handlers
        if (sigaction(SIGSEGV, &oldsa, NULL) == -1 ||
                sigaction(SIGBUS, &oldsb, NULL) == -1) {
@@ -1222,14 +1241,6 @@ static str CUDFeval(Client cntxt, MalBlk
                goto wrapup;
        }
 
-       actual_mprotected_regions[tid] = NULL;
-       // clear any mprotected regions
-       while (regions) {
-               mprotected_region *next = regions->next;
-               clear_mprotect(regions->addr, regions->len);
-               GDKfree(regions);
-               regions = next;
-       }
 
        // create the output bats from the returned results
        for (i = 0; i < (size_t)pci->retc; i++) {
@@ -1423,7 +1434,9 @@ wrapup:
        // clear any mprotected regions
        while (regions) {
                mprotected_region *next = regions->next;
-               clear_mprotect(regions->addr, regions->len);
+               if (regions->is_protected) {
+                       clear_mprotect(regions->addr, regions->len);    
+               }
                GDKfree(regions);
                regions = next;
        }
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to