This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit a7efd67b1d0b87975fdda7051ff48da4dcead9f0
Author: Mark Dokter <[email protected]>
AuthorDate: Mon Feb 22 16:00:42 2021 +0100

    [SYSTEMDS-2852] Improve SPOOF CUDA compilation
    
    Compilation is now invoked separately for each code template, and all flags
    the operator needs are passed explicitly.
    Also contains a bugfix for the full aggregation kernel template: the row and
    column indices passed to the operator were computed from the total element
    count N instead of the number of columns. All cellwise and rowwise template
    tests pass except for a few runtime inaccuracies.
---
 src/main/cuda/headers/reduction.cuh                |  20 +-
 src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp  | 216 ++++++++----------
 src/main/cuda/spoof-launcher/SpoofCUDAContext.h    | 114 +++++-----
 src/main/cuda/spoof-launcher/jni_bridge.cpp        | 247 ++++++++++++++-------
 src/main/cuda/spoof-launcher/jni_bridge.h          |  16 ++
 src/main/cuda/spoof/cellwise.cu                    |   1 +
 src/main/java/org/apache/sysds/common/Types.java   |  37 ++-
 .../apache/sysds/hops/codegen/SpoofCompiler.java   |  18 +-
 .../apache/sysds/hops/codegen/cplan/CNodeCell.java |  11 +
 .../sysds/hops/codegen/cplan/CNodeMultiAgg.java    |  10 +
 .../hops/codegen/cplan/CNodeOuterProduct.java      |  10 +
 .../apache/sysds/hops/codegen/cplan/CNodeRow.java  |  17 +-
 .../apache/sysds/hops/codegen/cplan/CNodeTpl.java  |   2 +
 .../apache/sysds/runtime/codegen/SpoofCUDA.java    |   6 +-
 .../sysds/runtime/codegen/SpoofCellwise.java       |  31 ++-
 .../apache/sysds/runtime/codegen/SpoofRowwise.java |  47 +++-
 .../test/functions/codegen/CellwiseTmplTest.java   |  13 +-
 17 files changed, 497 insertions(+), 319 deletions(-)

diff --git a/src/main/cuda/headers/reduction.cuh 
b/src/main/cuda/headers/reduction.cuh
index 8a024cc..568eb12 100644
--- a/src/main/cuda/headers/reduction.cuh
+++ b/src/main/cuda/headers/reduction.cuh
@@ -66,7 +66,8 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
        // number of active thread blocks (via gridDim).  More blocks will 
result
        // in a larger gridSize and therefore fewer elements per thread
        while (i < N) {
-               v = reduction_op(v, spoof_op(*(in->vals(i)), i, i / N, i % N));
+//             printf("tid=%d i=%d N=%d, in->cols()=%d rix=%d\n", threadIdx.x, 
i, N, in->cols(), i/in->cols());
+               v = reduction_op(v, spoof_op(*(in->vals(i)), i, i / in->cols(), 
i % in->cols()));
 
                if (i + blockDim.x < N) {
                        //__syncthreads();
@@ -106,7 +107,7 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
                }
                __syncthreads();
        }
-
+       
        if (tid < 32) {
                // now that we are using warp-synchronous programming (below)
                // we need to declare our shared memory volatile so that the 
compiler
@@ -115,27 +116,40 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
                if (blockDim.x >= 64) {
                        smem[tid] = v = reduction_op(v, smem[tid + 32]);
                }
+//             if(tid<12)
+//                     printf("bid=%d tid=%d reduction result: %3.1f\n", blockIdx.x, tid, sdata[tid]);
+               
                if (blockDim.x >= 32) {
                        smem[tid] = v = reduction_op(v, smem[tid + 16]);
                }
+//             if(tid==0)
+//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 16) {
                        smem[tid] = v = reduction_op(v, smem[tid + 8]);
                }
+//             if(tid==0)
+//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 8) {
                        smem[tid] = v = reduction_op(v, smem[tid + 4]);
                }
+//             if(tid==0)
+//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 4) {
                        smem[tid] = v = reduction_op(v, smem[tid + 2]);
                }
+//             if(tid==0)
+//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 2) {
                        smem[tid] = v = reduction_op(v, smem[tid + 1]);
                }
+//             if(tid==0)
+//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
        }
 
         // write result for this block to global mem
         if (tid == 0) {
 //             if(gridDim.x < 10)
-//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
+                       printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                out->val(0, blockIdx.x) = sdata[0];
         }
 }
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp 
b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
index d3f0778..3e023d5 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
@@ -24,49 +24,65 @@
 #include <cstdlib>
 #include <sstream>
 
+using clk = std::chrono::high_resolution_clock;
+using sec = std::chrono::duration<double, std::ratio<1>>;
+
 size_t SpoofCUDAContext::initialize_cuda(uint32_t device_id, const char* 
resource_path) {
 
-#ifdef __DEBUG
+#ifdef _DEBUG
        std::cout << "initializing cuda device " << device_id << std::endl;
 #endif
-
-  SpoofCUDAContext *ctx = new SpoofCUDAContext(resource_path);
-  // cuda device is handled by jCuda atm
-  //cudaSetDevice(device_id);
-  //cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
-  //cudaDeviceSynchronize();
-
-  CHECK_CUDA(cuModuleLoad(&(ctx->reductions), std::string(ctx->resource_path + 
std::string("/cuda/kernels/reduction.ptx")).c_str()));
-
-  CUfunction func;
-
-  // ToDo: implement a more scalable solution for these imports
-
-  // SUM
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_d"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_d", func));
-//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_f"));
-//  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_f", func));
-//
-//  // SUM_SQ
-//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_d"));
-//  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_d", func));
-//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_f"));
-//  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_f", func));
-//
-//  // MIN
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_d"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_min_d", func));
-//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_f"));
-//  ctx->reduction_kernels.insert(std::make_pair("reduce_min_f", func));
-//
-//  // MAX
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_d"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_max_d", func));
-//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_f"));
-//  ctx->reduction_kernels.insert(std::make_pair("reduce_max_f", func));
-
-  return reinterpret_cast<size_t>(ctx);
+       std::string cuda_include_path;
+       
+       char* cdp = std::getenv("CUDA_PATH");
+       if(cdp != nullptr)
+               cuda_include_path = std::string("-I") + std::string(cdp) + 
"/include";
+       else {
+               std::cout << "Warning: CUDA_PATH environment variable not set. 
Using default include path "
+                                        "/usr/local/cuda/include" << std::endl;
+               cuda_include_path = std::string("-I/usr/local/cuda/include");
+       }
+       
+       std::stringstream s1, s2;
+       s1 << "-I" << resource_path << "/cuda/headers";
+       s2 << "-I" << resource_path << "/cuda/spoof";
+       auto *ctx = new SpoofCUDAContext(resource_path, {s1.str(), s2.str(), 
cuda_include_path});
+       // cuda device is handled by jCuda atm
+       //cudaSetDevice(device_id);
+       //cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
+       //cudaDeviceSynchronize();
+       
+       CHECK_CUDA(cuModuleLoad(&(ctx->reductions), 
std::string(ctx->resource_path + 
std::string("/cuda/kernels/reduction.ptx")).c_str()));
+       
+       CUfunction func;
+       
+       // ToDo: implement a more scalable solution for these imports
+       
+       // SUM
+       CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_d"));
+       ctx->reduction_kernels.insert(std::make_pair("reduce_sum_d", func));
+       //  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, 
"reduce_sum_f"));
+       //  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_f", func));
+       //
+       //  // SUM_SQ
+       //  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, 
"reduce_sum_sq_d"));
+       //  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_d", 
func));
+       //  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, 
"reduce_sum_sq_f"));
+       //  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_f", 
func));
+       //
+       //  // MIN
+       CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_d"));
+       ctx->reduction_kernels.insert(std::make_pair("reduce_min_d", func));
+       //  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, 
"reduce_min_f"));
+       //  ctx->reduction_kernels.insert(std::make_pair("reduce_min_f", func));
+       //
+       //  // MAX
+       CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_d"));
+       ctx->reduction_kernels.insert(std::make_pair("reduce_max_d", func));
+       //  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, 
"reduce_max_f"));
+       //  ctx->reduction_kernels.insert(std::make_pair("reduce_max_f", func));
+
+       return reinterpret_cast<size_t>(ctx);
 }
 
 void SpoofCUDAContext::destroy_cuda(SpoofCUDAContext *ctx, uint32_t device_id) 
{
@@ -76,99 +92,43 @@ void SpoofCUDAContext::destroy_cuda(SpoofCUDAContext *ctx, 
uint32_t device_id) {
   //cudaDeviceReset();
 }
 
-bool SpoofCUDAContext::compile_cuda(const std::string &src,
-                                    const std::string &name) {
-    std::string cuda_include_path("");
-    char* cdp = std::getenv("CUDA_PATH");
-    if(cdp != nullptr)
-        cuda_include_path = std::string("-I") + std::string(cdp) + "/include";
-    else {
-       std::cout << "Warning: CUDA_PATH environment variable not set. Using 
default include path"
-                       "/usr/local/cuda/include" << std::endl;
-       cuda_include_path = std::string("-I/usr/local/cuda/include");
-    }
-
-#ifdef __DEBUG
-  std::cout << "compiling cuda kernel " << name << std::endl;
-  std::cout << src << std::endl;
-  std::cout << "cwd: " << std::filesystem::current_path() << std::endl;
-  std::cout << "cuda_path: " << cuda_include_path << std::endl;
+int SpoofCUDAContext::compile(const std::string &src, const std::string &name, 
SpoofOperator::OpType op_type,
+               SpoofOperator::AggType agg_type, SpoofOperator::AggOp agg_op, 
SpoofOperator::RowType row_type, bool sparse_safe,
+               int32_t const_dim2, uint32_t num_vectors, bool TB1) {
+#ifdef _DEBUG
+//     std::cout << "---=== START source listing of spoof cuda kernel [ " << 
name << " ]: " << std::endl;
+//    uint32_t line_num = 0;
+//     std::istringstream src_stream(src);
+//    for(std::string line; std::getline(src_stream, line); line_num++)
+//             std::cout << line_num << ": " << line << std::endl;
+//
+//     std::cout << "---=== END source listing of spoof cuda kernel [ " << 
name << " ]." << std::endl;
+       std::cout << "cwd: " << std::filesystem::current_path() << std::endl;
+       std::cout << "include_paths: ";
+       for_each (include_paths.begin(), include_paths.end(), [](const 
std::string& line){ std::cout << line << '\n';});
+       std::cout << std::endl;
 #endif
+// uncomment all related lines for temporary timing output:
+//     auto start = clk::now();
 
-       SpoofOperator::AggType agg_type= SpoofOperator::AggType::NONE;
-       SpoofOperator::AggOp agg_op = SpoofOperator::AggOp::NONE;
-       SpoofOperator::OpType op_type = SpoofOperator::OpType::NONE;
+//     auto compile_start = clk::now();
+       jitify::Program program = kernel_cache.program(src, 0, include_paths);
+//     auto compile_end = clk::now();
+//     auto compile_duration = std::chrono::duration_cast<sec>(compile_end - 
compile_start).count();
+       ops.insert(std::make_pair(name, SpoofOperator({std::move(program),  
op_type, agg_type, agg_op, row_type, name,
+                       const_dim2, num_vectors, TB1, sparse_safe})));
 
-       auto pos = 0;
-       
-       if((pos = src.find("SpoofCellwiseOp")) != std::string::npos)
-               op_type = SpoofOperator::OpType::CW;
-    else if((pos = src.find("SpoofRowwiseOp")) != std::string::npos)
-               op_type = SpoofOperator::OpType::RA;
-       else {
-        std::cerr << "error: unknown spoof operator" << std::endl;
-               return false;
-       }
+//     auto end = clk::now();
 
-       bool TB1 = false;
-       if((pos = src.find("TB1")) != std::string::npos)
-               if(src.substr(pos, pos+8).find("true") != std::string::npos)
-                       TB1 = true;
+//     auto handling_duration = std::chrono::duration_cast<sec>(end - 
start).count() - compile_duration;
 
-       uint32_t numTempVect = 0;
-       if((pos = src.find("// VectMem: ")) != std::string::npos)
-               numTempVect = std::stoi(std::string(src.begin() + pos + 12, 
std::find(src.begin()+pos, src.end(), '\n')));
-
-       int32_t constDim2 = 0;
-       if((pos = src.find("// ConstDim2: ")) != std::string::npos)
-               constDim2 = std::stoi(std::string(src.begin() + pos + 14, 
std::find(src.begin()+pos, src.end(), '\n')));
-
-       bool sparse_safe = false;
-       if ((pos = src.find("// SparseSafe:")) != std::string::npos)
-               if (src.substr(pos, pos + 15).find("true") != std::string::npos)
-                       sparse_safe = true;
-
-       if(((pos = src.find("CellType")) != std::string::npos) || ((pos = 
src.find("RowType")) != std::string::npos)){
-               if(src.substr(pos, pos+30).find("FULL_AGG") != 
std::string::npos)
-                       agg_type= SpoofOperator::AggType::FULL_AGG;
-               else if(src.substr(pos, pos+30).find("ROW_AGG") != 
std::string::npos)
-                       agg_type= SpoofOperator::AggType::ROW_AGG;
-               else if(src.substr(pos, pos+30).find("COL_AGG") != 
std::string::npos)
-                       agg_type= SpoofOperator::AggType::COL_AGG;
-               else if(src.substr(pos, pos+30).find("NO_AGG") != 
std::string::npos)
-                       agg_type= SpoofOperator::AggType::NO_AGG;
-               else if(src.substr(pos, pos+30).find("NO_AGG_CONST") != 
std::string::npos)
-                       agg_type= SpoofOperator::AggType::NO_AGG_CONST;
-               else if(src.substr(pos, pos+30).find("COL_AGG_T") != 
std::string::npos)
-                       agg_type= SpoofOperator::AggType::COL_AGG_T;
-               else {
-                       std::cerr << "error: unknown aggregation type" << 
std::endl;
-                       return false;
-               }
-
-               if((agg_type!= SpoofOperator::AggType::NO_AGG) && (op_type == 
SpoofOperator::OpType::CW)) {
-                       if((pos = src.find("AggOp")) != std::string::npos) {
-                               if(src.substr(pos, pos+30).find("AggOp.SUM") != 
std::string::npos)
-                                       agg_op = SpoofOperator::AggOp::SUM;
-                               else if(src.substr(pos, 
pos+30).find("AggOp.SUM_SQ") != std::string::npos)
-                                       agg_op = SpoofOperator::AggOp::SUM_SQ;
-                               else if(src.substr(pos, 
pos+30).find("AggOp.MIN") != std::string::npos)
-                                       agg_op = SpoofOperator::AggOp::MIN;
-                               else if(src.substr(pos, 
pos+30).find("AggOp.MAX") != std::string::npos)
-                                       agg_op = SpoofOperator::AggOp::MAX;
-                               else {
-                               std::cerr << "error: unknown aggregation 
operator" << std::endl;
-                                       return false;
-                               }
-                       }
-               }
-       }
-
-       std::stringstream s1, s2, s3;
-       s1 << "-I" << resource_path << "/cuda/headers";
-       s2 << "-I" << resource_path << "/cuda/spoof";
+//     compile_total += compile_duration;
+//     handling_total += handling_duration;
 
-       jitify::Program program = kernel_cache.program(src, 0, {s1.str(), 
s2.str(), cuda_include_path});
-       ops.insert(std::make_pair(name, SpoofOperator({std::move(program), 
agg_type, agg_op, op_type, name, TB1, constDim2, numTempVect, sparse_safe})));
-       return true;
+//     std::cout << name << " times [s] handling/compile/totals(h/c)/count: "
+//                     << handling_duration << "/"
+//                     << compile_duration << "/"
+//                     << handling_total << "/"
+//                     << compile_total << "/" << compile_count + 1 << 
std::endl;
+       return compile_count++;
 }
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h 
b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
index a9d3562..89dd9e0 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
@@ -21,7 +21,9 @@
 #ifndef SPOOFCUDACONTEXT_H
 #define SPOOFCUDACONTEXT_H
 
-#define NOMINMAX
+#if defined(_WIN32) || defined(_WIN64)
+       #define NOMINMAX
+#endif
 
 #include <cmath>
 #include <cstdint>
@@ -39,6 +41,7 @@
 #endif
 
 #include <jitify.hpp>
+#include <utility>
 #include <cublas_v2.h>
 
 #include "host_utils.h"
@@ -46,46 +49,57 @@
 using jitify::reflection::type_of;
 
 struct SpoofOperator {
-       enum class AggType : int { NO_AGG, NO_AGG_CONST, ROW_AGG, COL_AGG, 
FULL_AGG, COL_AGG_T, NONE };
-       enum class AggOp : int {SUM, SUM_SQ, MIN, MAX, NONE };
        enum class OpType : int { CW, RA, MA, OP, NONE };
+       enum class AggType : int { NONE, NO_AGG, FULL_AGG, ROW_AGG, COL_AGG };
+       enum class AggOp : int { NONE, SUM, SUM_SQ, MIN, MAX };
+       enum class RowType : int { NONE, FULL_AGG = 4 };
        
        jitify::Program program;
+       OpType op_type;
        AggType agg_type;
        AggOp agg_op;
-       OpType op_type;
+       RowType row_type;
     const std::string name;
-       bool TB1 = false;
        int32_t const_dim2;
        uint32_t num_temp_vectors;
+       bool TB1 = false;
        bool sparse_safe = true;
 };
 
 class SpoofCUDAContext {
-
-  jitify::JitCache kernel_cache;
-  std::map<const std::string, SpoofOperator> ops;
-  CUmodule reductions;
-  std::map<const std::string, CUfunction> reduction_kernels;
-
+       jitify::JitCache kernel_cache;
+       std::map<const std::string, SpoofOperator> ops;
+       CUmodule reductions;
+       std::map<const std::string, CUfunction> reduction_kernels;
+       double handling_total, compile_total;
+       uint32_t compile_count;
+       
+       const std::string resource_path;
+       const std::vector<std::string> include_paths;
+       
 public:
-  // ToDo: make launch config more adaptive
-  // num threads
-  const int NT = 256;
-
-  // values / thread
-  const int VT = 4;
+       // ToDo: make launch config more adaptive
+       // num threads
+       const int NT = 256;
 
-  const std::string resource_path;
+       // values / thread
+       const int VT = 4;
 
-  SpoofCUDAContext(const char* resource_path_) : reductions(nullptr), 
resource_path(resource_path_) {}
+       explicit SpoofCUDAContext(const char* resource_path_, std::vector<std::string> include_paths_) : reductions(nullptr),
+                       handling_total(0.0), compile_total(0.0), compile_count(0), resource_path(resource_path_),
+                       include_paths(std::move(include_paths_)) {} // initializers listed in member declaration order (avoids -Wreorder)
 
-  static size_t initialize_cuda(uint32_t device_id, const char* 
resource_path_);
+       static size_t initialize_cuda(uint32_t device_id, const char* 
resource_path_);
 
-  static void destroy_cuda(SpoofCUDAContext *ctx, uint32_t device_id);
+       static void destroy_cuda(SpoofCUDAContext *ctx, uint32_t device_id);
 
-  bool compile_cuda(const std::string &src, const std::string &name);
+       int compile(const std::string &src, const std::string &name, 
SpoofOperator::OpType op_type,
+                       SpoofOperator::AggType agg_type = 
SpoofOperator::AggType::NONE,
+                       SpoofOperator::AggOp agg_op = 
SpoofOperator::AggOp::NONE,
+                       SpoofOperator::RowType row_type = 
SpoofOperator::RowType::NONE, bool sparse_safe = true,
+                       int32_t const_dim2 = -1, uint32_t num_vectors = 0, bool 
TB1 = false);
 
+       
        template <typename T>
        T execute_kernel(const std::string &name, std::vector<Matrix<T>>& 
input,  std::vector<Matrix<T>>& sides, Matrix<T>* output,
                                         T *scalars_ptr, uint32_t num_scalars, 
uint32_t grix) {
@@ -107,8 +121,8 @@ public:
                        CHECK_CUDART(cudaMemcpy(d_in, 
reinterpret_cast<void*>(&input[0]), sizeof(Matrix<T>), cudaMemcpyHostToDevice));
 
                        if(output != nullptr) {
-                               if (op->sparse_safe == true && 
input.front().row_ptr != nullptr) {
-#ifdef __DEBUG
+                               if (op->sparse_safe && input.front().row_ptr != 
nullptr) {
+#ifdef _DEBUG
                                        std::cout << "copying sparse safe row 
ptrs" << std::endl;
 #endif
                                        
CHECK_CUDART(cudaMemcpy(output->row_ptr, input.front().row_ptr, 
(input.front().rows+1)*sizeof(uint32_t), cudaMemcpyDeviceToDevice));
@@ -135,7 +149,7 @@ public:
                        
                        if (!sides.empty()) {
                                if(op->TB1) {
-#ifdef __DEBUG
+#ifdef _DEBUG
                                        std::cout << "transposing TB1 for " << 
op->name << std::endl;
 #endif
                                        T* b1 = sides[0].data;
@@ -175,8 +189,7 @@ public:
                                                                 
input.front().cols, grix, input[0].row_ptr!=nullptr);
                        break;
                    default:
-                               std::cerr << "error: unknown spoof operator" << 
std::endl;
-                       return result;
+                               throw std::runtime_error("error: unknown spoof 
operator");
                    }
                        
                        if (num_scalars > 0)
@@ -188,9 +201,7 @@ public:
                        if (op->TB1)
                                CHECK_CUDART(cudaFree(b1_transposed));
                        
-                       if (op->agg_type == SpoofOperator::AggType::FULL_AGG) {
-//                             std::cout << "retrieving scalar result" << 
std::endl;
-                               
+                       if (op->agg_type == SpoofOperator::AggType::FULL_AGG || 
op->row_type == SpoofOperator::RowType::FULL_AGG) {
                                Matrix<T> res_mat;
                                CHECK_CUDART(cudaMemcpy(&res_mat, d_out, 
sizeof(Matrix<T>), cudaMemcpyDeviceToHost));
                                CHECK_CUDART(cudaMemcpy(&result, res_mat.data, 
sizeof(T), cudaMemcpyDeviceToHost));
@@ -200,8 +211,7 @@ public:
                        }
                } 
                else {
-                       std::cerr << "kernel " << name << " not found." << 
std::endl;
-                       return result;
+                       throw std::runtime_error("kernel " + name + " not 
found.");
                }
                return result;
        }
@@ -222,7 +232,7 @@ public:
                  reduction_type = "_col_";
                  break;
          default:
-                 std::cerr << "unknown reduction type" << std::endl;
+                 throw std::runtime_error("unknown reduction type");
                  return "";
          }
        
@@ -240,8 +250,7 @@ public:
                  reduction_kernel_name = "reduce" + reduction_type + "sum" + 
suffix;
                  break;
          default:
-                 std::cerr << "unknown reduction op" << std::endl;
-                 return "";
+                 throw std::runtime_error("unknown reduction op");
          }
        
          return reduction_kernel_name;
@@ -267,7 +276,7 @@ public:
                        dim3 block(NT, 1, 1);
                        uint32_t shared_mem_size = NT * sizeof(T);
 
-//#ifdef __DEBUG
+#ifdef _DEBUG
                                // ToDo: connect output to SystemDS logging 
facilities
                                std::cout << "launching spoof cellwise kernel " 
<< op_name << " with "
                                                  << NT * NB << " threads in " 
<< NB << " blocks and "
@@ -275,7 +284,7 @@ public:
                                                  << " bytes of shared memory 
for full aggregation of "
                                                  << N << " elements"
                                                  << std::endl;
-//#endif
+#endif
                        
                        CHECK_CUDA(op->program.kernel(op_name)
                                                           
.instantiate(type_of(value_type), std::max(1u, num_sides))
@@ -284,7 +293,6 @@ public:
                        
                        if(NB > 1) {
                                std::string reduction_kernel_name = 
determine_agg_kernel<T>(op);
-
                                CUfunction reduce_kernel = 
reduction_kernels.find(reduction_kernel_name)->second;
                                N = NB;
                                uint32_t iter = 1;
@@ -292,18 +300,18 @@ public:
                                        void* args[3] = { &d_out, &d_out, &N};
 
                                        NB = std::ceil((N + NT * 2 - 1) / (NT * 
2));
-//#ifdef __DEBUG
+#ifdef _DEBUG
                                        std::cout << "agg iter " << iter++ << " 
launching spoof cellwise kernel " << op_name << " with "
                     << NT * NB << " threads in " << NB << " blocks and "
                     << shared_mem_size
                     << " bytes of shared memory for full aggregation of "
                     << N << " elements"
                     << std::endl;
-//#endif
+#endif
                                        CHECK_CUDA(cuLaunchKernel(reduce_kernel,
                                                        NB, 1, 1,
                                                        NT, 1, 1,
-                                                       shared_mem_size, 0, 
args, 0));
+                                                       shared_mem_size, 
nullptr, args, nullptr));
                                                        N = NB;
                                }
                        }
@@ -315,7 +323,7 @@ public:
                        dim3 grid(NB, 1, 1);
                        dim3 block(NT, 1, 1);
                        uint32_t shared_mem_size = 0;
-#ifdef __DEBUG
+#ifdef _DEBUG
                        std::cout << " launching spoof cellwise kernel " << 
op_name << " with "
                                        << NT * NB << " threads in " << NB << " 
blocks for column aggregation of "
                                        << N << " elements" << std::endl;
@@ -333,12 +341,12 @@ public:
                        dim3 grid(NB, 1, 1);
                        dim3 block(NT, 1, 1);
                        uint32_t shared_mem_size = NT * sizeof(T);
-//#ifdef __DEBUG
+#ifdef _DEBUG
                        std::cout << " launching spoof cellwise kernel " << 
op_name << " with "
                                        << NT * NB << " threads in " << NB << " 
blocks and "
                                        << shared_mem_size << " bytes of shared 
memory for row aggregation of "
                                        << N << " elements" << std::endl;
-//#endif
+#endif
                        CHECK_CUDA(op->program.kernel(op_name)
                                                           
.instantiate(type_of(value_type), std::max(1u, num_sides))
                                                           .configure(grid, 
block, shared_mem_size)
@@ -357,7 +365,7 @@ public:
                        dim3 block(NT, 1, 1);
                        uint32_t shared_mem_size = 0;
 
-#ifdef __DEBUG
+#ifdef _DEBUG
                        if(sparse) {
                                std::cout << "launching sparse spoof cellwise 
kernel " << op_name << " with " << NT * NB
                                                  << " threads in " << NB << " 
blocks without aggregation for " << N << " elements"
@@ -369,10 +377,6 @@ public:
                                                  << std::endl;
                        }
 #endif
-//                     CHECK_CUDA(op->program.kernel(op_name)
-//                                     .instantiate(type_of(result))
-//                                     .configure(grid, block)
-//                                     .launch(in_ptrs[0], d_sides, out_ptr, 
d_scalars, m, n, grix));
 
                        CHECK_CUDA(op->program.kernel(op_name)
                                                           
.instantiate(type_of(value_type), std::max(1u, num_sides))
@@ -404,12 +408,12 @@ public:
                if(sparse)
                        name = std::string(op->name + "_SPARSE");
 
-//#ifdef __DEBUG
-                       // ToDo: connect output to SystemDS logging facilities
-                       std::cout << "launching spoof rowwise kernel " << name 
<< " with " << NT * in_rows << " threads in " << in_rows
-                               << " blocks and " << shared_mem_size << " bytes 
of shared memory for " << in_cols << " cols processed by "
-                               << NT << " threads per row, adding " << 
temp_buf_size / 1024 << " kb of temp buffer in global memory." <<  std::endl;
-//#endif
+#ifdef _DEBUG
+               // ToDo: connect output to SystemDS logging facilities
+               std::cout << "launching spoof rowwise kernel " << name << " 
with " << NT * in_rows << " threads in " << in_rows
+                       << " blocks and " << shared_mem_size << " bytes of 
shared memory for " << in_cols << " cols processed by "
+                       << NT << " threads per row, adding " << temp_buf_size / 
1024 << " kb of temp buffer in global memory." <<  std::endl;
+#endif
                CHECK_CUDA(op->program.kernel(name)
                                .instantiate(type_of(value_type), std::max(1u, 
num_sides), op->num_temp_vectors, tmp_len)
                                .configure(grid, block, shared_mem_size)
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.cpp 
b/src/main/cuda/spoof-launcher/jni_bridge.cpp
index af6967d..fd95d1b 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.cpp
+++ b/src/main/cuda/spoof-launcher/jni_bridge.cpp
@@ -26,10 +26,42 @@
 
 #define RELEASE_ARRAY(env, java, cpp)(env->ReleasePrimitiveArrayCritical(java, 
cpp, 0))
 
+// error output helper: reports a std::exception thrown while compiling or executing a SPOOF CUDA operator on stderr
+void printStdException(JNIEnv *env, jstring name, const std::exception& e, 
bool compile = false) {
+       std::string type = compile ? "compiling" : "executing";
+       if(name != nullptr) {
+               const char *cstr_name = env->GetStringUTFChars(name, nullptr);
+               std::cerr << "std::exception while " << type << " SPOOF CUDA 
operator " << cstr_name << ":\n" << e.what() <<
+                                 std::endl;
+               env->ReleaseStringUTFChars(name, cstr_name);
+       }
+       else
+               std::cerr << "std::exception while " << type << " SPOOF CUDA 
operator (name=nullptr):\n" << e.what() <<
+                               std::endl;
+}
+
+void printException(JNIEnv* env, jstring name, bool compile = false) { // reports an exception caught via catch(...) on stderr
+       std::string type = compile ? "compiling" : "executing";
+       if(name != nullptr) {
+               const char *cstr_name = env->GetStringUTFChars(name, nullptr);
+               std::cerr << "Unknown exception occurred while " << type << " 
SPOOF CUDA operator " << cstr_name << std::endl;
+               env->ReleaseStringUTFChars(name, cstr_name);
+       }
+       else
+               std::cerr << "Unknown exception occurred while " << type << " 
SPOOF CUDA operator (name=nullptr)" << std::endl;
+}
+
+void printSource(JNIEnv* env, jstring name, jstring src) { // dumps the generated source of a failed operator to stderr; 'name' is currently unused
+       if(src != nullptr) {
+               const char *cstr_src = env->GetStringUTFChars(src, nullptr);
+               std::cerr << "Source code:\n" << cstr_src << std::endl;
+               env->ReleaseStringUTFChars(src, cstr_src); // must be released against the jstring the chars were obtained from
+       }
+}
+
 JNIEXPORT jlong JNICALL
 Java_org_apache_sysds_hops_codegen_SpoofCompiler_initialize_1cuda_1context(
     JNIEnv *env, jobject jobj, jint device_id, jstring resource_path) {
-
   const char *cstr_rp = env->GetStringUTFChars(resource_path, nullptr);
   size_t ctx = SpoofCUDAContext::initialize_cuda(device_id, cstr_rp);
   env->ReleaseStringUTFChars(resource_path, cstr_rp);
@@ -42,54 +74,99 @@ 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context(
   SpoofCUDAContext::destroy_cuda(reinterpret_cast<SpoofCUDAContext *>(ctx), 
device_id);
 }
 
-JNIEXPORT jboolean JNICALL
-Java_org_apache_sysds_hops_codegen_SpoofCompiler_compile_1cuda_1kernel(
-    JNIEnv *env, jobject jobj, jlong ctx, jstring name, jstring src) {
-  auto *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx);
-  const char *cstr_name = env->GetStringUTFChars(name, nullptr);
-  const char *cstr_src = env->GetStringUTFChars(src, nullptr);
-  bool result = ctx_->compile_cuda(cstr_src, cstr_name);
-  env->ReleaseStringUTFChars(src, cstr_src);
-  env->ReleaseStringUTFChars(name, cstr_name);
-  return result;
+JNIEXPORT jint JNICALL 
Java_org_apache_sysds_hops_codegen_cplan_CNodeCell_compile_1nvrtc
+               (JNIEnv *env, jobject jobj, jlong ctx, jstring name, jstring 
src, jint type, jint agg_op, jboolean sparseSafe) {
+       try {
+               auto *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx); // native context handle passed from Java (SpoofCompiler.native_contexts)
+               const char *cstr_name = env->GetStringUTFChars(name, nullptr); // NOTE(review): cstr_name/cstr_src are not released if compile() throws
+               const char *cstr_src = env->GetStringUTFChars(src, nullptr);
+               
+               int result = ctx_->compile(cstr_src, cstr_name, 
SpoofOperator::OpType::CW,
+                               SpoofOperator::AggType(type), 
SpoofOperator::AggOp(agg_op), SpoofOperator::RowType::NONE, sparseSafe);
+               
+               env->ReleaseStringUTFChars(src, cstr_src);
+               env->ReleaseStringUTFChars(name, cstr_name);
+               return result;
+       }
+       catch (std::exception& e) {
+               printStdException(env, name, e, true);
+       }
+       catch (...) {
+               printException(env, name, true);
+       }
+       printSource(env, name, src); // only reached when compilation threw
+       return -1; // failure marker; the Java caller treats ids < 0 as "not compiled"
+}
+
+/*
+ * Class:     org_apache_sysds_hops_codegen_cplan_CNodeRow
+ * Method:    compile_nvrtc
+ * Signature: (JLjava/lang/String;Ljava/lang/String;IIIZ)I
+ */
+JNIEXPORT jint JNICALL 
Java_org_apache_sysds_hops_codegen_cplan_CNodeRow_compile_1nvrtc
+               (JNIEnv *env, jobject jobj, jlong ctx, jstring name, jstring 
src, jint type, jint const_dim2, jint num_vectors,
+               jboolean TB1) {
+       try {
+               auto *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx); // native context handle passed from Java (SpoofCompiler.native_contexts)
+               const char *cstr_name = env->GetStringUTFChars(name, nullptr); // NOTE(review): cstr_name/cstr_src are not released if compile() throws
+               const char *cstr_src = env->GetStringUTFChars(src, nullptr);
+               
+               int result = ctx_->compile(cstr_src, cstr_name, 
SpoofOperator::OpType::RA,
+                               SpoofOperator::AggType::NONE, 
SpoofOperator::AggOp::NONE,
+                               SpoofOperator::RowType(type), true, const_dim2, 
num_vectors, TB1); // 'true' occupies the slot the CNodeCell bridge uses for sparseSafe
+               
+               env->ReleaseStringUTFChars(src, cstr_src);
+               env->ReleaseStringUTFChars(name, cstr_name);
+               return result;
+       }
+       catch (std::exception& e) {
+               printStdException(env, name, e, true); // compile-time failure: report "compiling", like the CNodeCell bridge
+       }
+       catch (...) {
+               printException(env, name, true);
+       }
+       printSource(env, name, src); // only reached when compilation threw
+       return -1; // failure marker; the Java caller treats ids < 0 as "not compiled"
 }
 
+
 JNIEXPORT jdouble JNICALL
 Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1d(
     JNIEnv *env, jobject jobj, jlong ctx, jstring name, jlongArray in_ptrs, 
jint input_offset, jlongArray side_ptrs,
                jlongArray out_ptrs, jdoubleArray scalars_, jlong grix, jobject 
inputs_, jobject out_obj) {
-
-       auto *ctx_ = reinterpret_cast<SpoofCUDAContext*>(ctx);
-       const char *cstr_name = env->GetStringUTFChars(name, nullptr);
-
-       auto* inputs = reinterpret_cast<size_t*>(GET_ARRAY(env, in_ptrs));
-       auto* sides = reinterpret_cast<size_t*>(GET_ARRAY(env, side_ptrs));
-       auto *output = reinterpret_cast<size_t*>(GET_ARRAY(env, out_ptrs));
-       auto *scalars = reinterpret_cast<double*>(GET_ARRAY(env, scalars_));
-
-       //ToDo: call once while init
-       jclass CacheableData = 
env->FindClass("org/apache/sysds/runtime/controlprogram/caching/CacheableData");
-       if(!CacheableData) {
-               std::cerr << " JNIEnv -> FindClass(CacheableData) failed" << 
std::endl;
-               return -1.0;
-       }
-       jclass ArrayList = env->FindClass("java/util/ArrayList");
-       if(!ArrayList) {
-               std::cerr << " JNIEnv -> FindClass(ArrayList) failed" << 
std::endl;
-               return -1.0;
-       }
-       jmethodID mat_obj_num_rows = env->GetMethodID(CacheableData, 
"getNumRows", "()J");
-       if(!mat_obj_num_rows) {
-               std::cerr << " JNIEnv -> GetMethodID() failed" << std::endl;
-               return -1.0;
-       }
-       jmethodID mat_obj_num_cols = env->GetMethodID(CacheableData, 
"getNumColumns", "()J");
-       if(!mat_obj_num_cols) {
-               std::cerr << " JNIEnv -> GetMethodID() failed" << std::endl;
-               return -1.0;
-       }
-       jmethodID ArrayList_size = env->GetMethodID(ArrayList, "size", "()I");
-       jmethodID ArrayList_get = env->GetMethodID(ArrayList, "get", 
"(I)Ljava/lang/Object;");
+       double result = 0.0; // stays 0.0 if an exception is caught below
+       try {
+           auto *ctx_ = reinterpret_cast<SpoofCUDAContext*>(ctx);
+           
+               //ToDo: call once while init. Resolved before any string/array is acquired, so the error returns below leak nothing.
+               jclass CacheableData = 
env->FindClass("org/apache/sysds/runtime/controlprogram/caching/CacheableData");
+               if (!CacheableData) {
+                       std::cerr << " JNIEnv -> FindClass(CacheableData) 
failed" << std::endl;
+                       return -1.0;
+               }
+               jclass ArrayList = env->FindClass("java/util/ArrayList");
+               if (!ArrayList) {
+                       std::cerr << " JNIEnv -> FindClass(ArrayList) failed" 
<< std::endl;
+                       return -1.0;
+               }
+               jmethodID mat_obj_num_rows = env->GetMethodID(CacheableData, 
"getNumRows", "()J");
+               if (!mat_obj_num_rows) {
+                       std::cerr << " JNIEnv -> GetMethodID() failed" << 
std::endl;
+                       return -1.0;
+               }
+               jmethodID mat_obj_num_cols = env->GetMethodID(CacheableData, 
"getNumColumns", "()J");
+               if (!mat_obj_num_cols) {
+                       std::cerr << " JNIEnv -> GetMethodID() failed" << 
std::endl;
+                       return -1.0;
+               }
+               jmethodID ArrayList_size = env->GetMethodID(ArrayList, "size", 
"()I");
+               jmethodID ArrayList_get = env->GetMethodID(ArrayList, "get", 
"(I)Ljava/lang/Object;");
+
+           const char *cstr_name = env->GetStringUTFChars(name, nullptr);
+               auto* inputs = reinterpret_cast<size_t*>(GET_ARRAY(env, 
in_ptrs));
+               auto* sides = reinterpret_cast<size_t*>(GET_ARRAY(env, 
side_ptrs));
+               auto *output = reinterpret_cast<size_t*>(GET_ARRAY(env, 
out_ptrs));
+               auto *scalars = reinterpret_cast<double*>(GET_ARRAY(env, 
scalars_)); // NOTE(review): assumes GET_ARRAY is the Critical variant (RELEASE_ARRAY is); the loops below still call JNI while these are held
 
                std::vector<Matrix<double>> in;
                jint num_inputs = env->CallIntMethod(inputs_, ArrayList_size);
@@ -97,50 +174,58 @@ Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1d(
                std::cout << "num inputs: " << num_inputs << " offsets: " << 
input_offset << std::endl;
 #endif
 
-       for(auto ptr_idx = 0, input_idx = 0; input_idx < input_offset; 
ptr_idx+=4, input_idx++) {
-               jobject input_obj = env->CallObjectMethod(inputs_, 
ArrayList_get, input_idx);
-               auto m = static_cast<uint32_t>(env->CallIntMethod(input_obj, 
mat_obj_num_rows));
-               auto n = static_cast<uint32_t>(env->CallIntMethod(input_obj, 
mat_obj_num_cols));
+        for(auto ptr_idx = 0, input_idx = 0; input_idx < input_offset; 
ptr_idx+=4, input_idx++) {
+            jobject input_obj = env->CallObjectMethod(inputs_, ArrayList_get, 
input_idx);
+            auto m = static_cast<uint32_t>(env->CallLongMethod(input_obj, 
mat_obj_num_rows)); // getNumRows/getNumColumns return long ("()J")
+            auto n = static_cast<uint32_t>(env->CallLongMethod(input_obj, 
mat_obj_num_cols));
 
-               in.push_back(Matrix<double>{reinterpret_cast<double 
*>(inputs[ptr_idx + 3]),
-                                                                       
reinterpret_cast<uint32_t *>(inputs[ptr_idx + 1]),
-                                                                       
reinterpret_cast<uint32_t *>(inputs[ptr_idx + 2]),
-                                                                       m, n, 
static_cast<uint32_t>(inputs[ptr_idx])});
+            in.push_back(Matrix<double>{reinterpret_cast<double 
*>(inputs[ptr_idx + 3]),
+                                        reinterpret_cast<uint32_t 
*>(inputs[ptr_idx + 1]),
+                                        reinterpret_cast<uint32_t 
*>(inputs[ptr_idx + 2]),
+                                        m, n, 
static_cast<uint32_t>(inputs[ptr_idx])});
 #ifdef _DEBUG
                        std::cout << "input #" << input_idx << " m=" << m << " 
n=" << n << std::endl;
 #endif
+               }
+
+        std::vector<Matrix<double>> side_inputs;
+        for(uint32_t ptr_idx = 0, input_idx = input_offset; input_idx < 
num_inputs; ptr_idx+=4, input_idx++) {
+            jobject side_input_obj = env->CallObjectMethod(inputs_, 
ArrayList_get, input_idx);
+            auto m = static_cast<uint32_t>(env->CallLongMethod(side_input_obj, 
mat_obj_num_rows));
+            auto n = static_cast<uint32_t>(env->CallLongMethod(side_input_obj, 
mat_obj_num_cols));
+
+            side_inputs.push_back(Matrix<double>{reinterpret_cast<double 
*>(sides[ptr_idx + 3]),
+                                                 reinterpret_cast<uint32_t 
*>(sides[ptr_idx + 1]),
+                                                 reinterpret_cast<uint32_t 
*>(sides[ptr_idx + 2]),
+                                                 m, n, 
static_cast<uint32_t>(sides[ptr_idx])});
+        }
+
+        std::unique_ptr<Matrix<double>> out;
+        if(out_obj != nullptr) {
+            out = 
std::make_unique<Matrix<double>>(Matrix<double>{reinterpret_cast<double*>(output[3]),
+                                                    
reinterpret_cast<uint32_t*>(output[1]),
+                                                    
reinterpret_cast<uint32_t*>(output[2]),
+                                                    
static_cast<uint32_t>(env->CallLongMethod(out_obj, mat_obj_num_rows)),
+                                                    
static_cast<uint32_t>(env->CallLongMethod(out_obj, mat_obj_num_cols)),
+                                                    
static_cast<uint32_t>(output[0])});
+        }
+
+               result = ctx_->execute_kernel(cstr_name, in, side_inputs, 
out.get(), scalars,
+                                                                               
         env->GetArrayLength(scalars_), grix);
+
+               RELEASE_ARRAY(env, in_ptrs, inputs);
+               RELEASE_ARRAY(env, side_ptrs, sides);
+               RELEASE_ARRAY(env, out_ptrs, output);
+               RELEASE_ARRAY(env, scalars_, scalars);
+               
+               env->ReleaseStringUTFChars(name, cstr_name);
+               return result;
        }
-
-       std::vector<Matrix<double>> side_inputs;
-       for(uint32_t ptr_idx = 0, input_idx = input_offset; input_idx < 
num_inputs; ptr_idx+=4, input_idx++) {
-               jobject side_input_obj = env->CallObjectMethod(inputs_, 
ArrayList_get, input_idx);
-               auto m = 
static_cast<uint32_t>(env->CallIntMethod(side_input_obj, mat_obj_num_rows));
-               auto n = 
static_cast<uint32_t>(env->CallIntMethod(side_input_obj, mat_obj_num_cols));
-
-               side_inputs.push_back(Matrix<double>{reinterpret_cast<double 
*>(sides[ptr_idx + 3]),
-                                                                               
         reinterpret_cast<uint32_t *>(sides[ptr_idx + 1]),
-                                                                               
         reinterpret_cast<uint32_t *>(sides[ptr_idx + 2]),
-                                                                               
         m, n, static_cast<uint32_t>(sides[ptr_idx])});
+       catch (std::exception& e) {
+               printStdException(env, name, e);
        }
-
-       std::unique_ptr<Matrix<double>> out;
-       if(out_obj != nullptr) {
-               out = 
std::make_unique<Matrix<double>>(Matrix<double>{reinterpret_cast<double*>(output[3]),
-                                                                               
                reinterpret_cast<uint32_t*>(output[1]),
-                                                                               
                reinterpret_cast<uint32_t*>(output[2]),
-                                                                               
                static_cast<uint32_t>(env->CallIntMethod(out_obj, 
mat_obj_num_rows)),
-                                                                               
                static_cast<uint32_t>(env->CallIntMethod(out_obj, 
mat_obj_num_cols)),
-                                                                               
                static_cast<uint32_t>(output[0])});
+       catch (...) {
+               printException(env, name);
        }
-
-       double result = ctx_->execute_kernel(cstr_name, in, side_inputs, 
out.get(), scalars, env->GetArrayLength(scalars_),
-                                                                               
        grix);
-
-       RELEASE_ARRAY(env, in_ptrs, inputs);
-       RELEASE_ARRAY(env, side_ptrs, sides);
-       RELEASE_ARRAY(env, out_ptrs, output);
-       RELEASE_ARRAY(env, scalars_, scalars);
-
-       env->ReleaseStringUTFChars(name, cstr_name);
        return result;
 }
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.h 
b/src/main/cuda/spoof-launcher/jni_bridge.h
index 788f93e..48c4882 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.h
+++ b/src/main/cuda/spoof-launcher/jni_bridge.h
@@ -48,6 +48,22 @@ 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context(
     JNIEnv *, jobject, jlong, jint);
 
 /*
+ * Class:     org_apache_sysds_hops_codegen_cplan_CNodeCell
+ * Method:    compile_nvrtc
+ * Signature: (JLjava/lang/String;Ljava/lang/String;IIZ)I
+ */
+JNIEXPORT jint JNICALL 
Java_org_apache_sysds_hops_codegen_cplan_CNodeCell_compile_1nvrtc
+               (JNIEnv *, jobject, jlong, jstring, jstring, jint, jint, 
jboolean); // ctx, name, src, type, agg_op, sparseSafe
+
+/*
+ * Class:     org_apache_sysds_hops_codegen_cplan_CNodeRow
+ * Method:    compile_nvrtc
+ * Signature: (JLjava/lang/String;Ljava/lang/String;IIIZ)I
+ */
+JNIEXPORT jint JNICALL 
Java_org_apache_sysds_hops_codegen_cplan_CNodeRow_compile_1nvrtc
+               (JNIEnv *, jobject, jlong, jstring, jstring, jint, jint, jint, 
jboolean); // ctx, name, src, type, const_dim2, num_vectors, TB1
+
+/*
  * Class:     org_apache_sysds_hops_codegen_SpoofCompiler
  * Method:    compile_cuda_kernel
  * Signature: (JLjava/lang/String;Ljava/lang/String;)Z
diff --git a/src/main/cuda/spoof/cellwise.cu b/src/main/cuda/spoof/cellwise.cu
index f26161b..ca25fe5 100644
--- a/src/main/cuda/spoof/cellwise.cu
+++ b/src/main/cuda/spoof/cellwise.cu
@@ -58,6 +58,7 @@ struct SpoofCellwiseOp {
 //%NEED_GRIX%
 
 %BODY_dense%
+//printf("tid=%d a=%4.1f\n", threadIdx.x, a);
                return %OUT%;
        }
 };
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 6c179cb..b9be432 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -20,10 +20,10 @@
 package org.apache.sysds.common;
 
 import java.util.Arrays;
+import java.util.HashMap;
 
 import org.apache.sysds.runtime.DMLRuntimeException;
 
-
 public class Types
 {
        /**
@@ -174,14 +174,14 @@ public class Types
                }
        }
        
+       // these codes must match their native counterparts (spoof 
cuda ops); they are what getValue() returns, independent of ordinal()
        public enum AggOp {
-               SUM, SUM_SQ,
-               PROD, SUM_PROD,
-               MIN, MAX,
-               TRACE, MEAN, VAR,
-               MAXINDEX, MININDEX,
-               COUNT_DISTINCT,
-               COUNT_DISTINCT_APPROX;
+               SUM(1), SUM_SQ(2), MIN(3), MAX(4), // codes start at 1
+               PROD(5), SUM_PROD(6),
+               TRACE(7), MEAN(8), VAR(9),
+               MAXINDEX(10), MININDEX(11),
+               COUNT_DISTINCT(12),
+               COUNT_DISTINCT_APPROX(13);
                
                @Override
                public String toString() {
@@ -192,6 +192,27 @@ public class Types
                                default:     return name().toLowerCase();
                        }
                }
+               
+               private final int value; // numeric code shared with native code
+               private final static HashMap<Integer, AggOp> map = new 
HashMap<>(); // code -> constant, filled by the static block
+               
+               AggOp(int value) {
+                       this.value = value;
+               }
+               
+               static {
+                       for (AggOp aggOp : AggOp.values()) {
+                               map.put(aggOp.value, aggOp);
+                       }
+               }
+               
+               public static AggOp valueOf(int aggOp) { // reverse lookup; returns null for an unknown code
+                       return map.get(aggOp);
+               }
+               
+               public int getValue() { // the numeric code of this constant
+                       return value;
+               }
        }
        
        // Operations that require 1 operand
diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index 4e8e61e..46a7042 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -238,8 +238,6 @@ public class SpoofCompiler {
                                        isLoaded = 
NativeHelper.loadLibraryHelperFromResource(libName);
 
                                if(isLoaded) {
-
-
                                        long ctx_ptr = 
initialize_cuda_context(0, local_tmp);
                                        if(ctx_ptr != 0) {
                                                
native_contexts.put(GeneratorAPI.CUDA, ctx_ptr);
@@ -289,14 +287,8 @@ public class SpoofCompiler {
                }
        }
 
-       private static boolean compile_cuda(String name, String src) {
-               return 
compile_cuda_kernel(native_contexts.get(GeneratorAPI.CUDA), name, src);
-       }
-
        private static native long initialize_cuda_context(int device_id, 
String resource_path);
 
-       private static native boolean compile_cuda_kernel(long ctx, String 
name, String src);
-
        private static native void destroy_cuda_context(long ctx, int 
device_id);
 
        //plan cache for cplan->compiled source to avoid unnecessary 
codegen/source code compile
@@ -514,18 +506,16 @@ public class SpoofCompiler {
                                Class<?> cla = 
planCache.getPlan(tmp.getValue());
                                
                                if( cla == null ) {
-                                       String src = "";
                                        String src_cuda = "";
-                                       boolean native_compiled_successfully = 
false;
-                                       src = tmp.getValue().codegen(false, 
GeneratorAPI.JAVA);
+                                       String src = 
tmp.getValue().codegen(false, GeneratorAPI.JAVA);
                                        cla = 
CodegenUtils.compileClass("codegen."+ tmp.getValue().getClassname(), src);
 
                                        if(API == GeneratorAPI.CUDA) {
                                                
if(tmp.getValue().isSupported(API)) {
                                                        src_cuda = 
tmp.getValue().codegen(false, GeneratorAPI.CUDA);
-                                                       
native_compiled_successfully = compile_cuda(tmp.getValue().getVarname(), 
src_cuda);
-                                                       
if(native_compiled_successfully)
-                                                               
CodegenUtils.putNativeOpData(new SpoofCUDA(src_cuda, tmp.getValue(), cla));
+                                                       int op_id = 
tmp.getValue().compile(API, src_cuda);
+                                                       if(op_id >= 0)
+                                                               
CodegenUtils.putNativeOpData(new SpoofCUDA(src_cuda, tmp.getValue(), op_id));
                                                        else {
                                                                LOG.warn("CUDA 
compilation failed, falling back to JAVA");
                                                                
tmp.getValue().setGeneratorAPI(GeneratorAPI.JAVA);
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
index f2f0179..9d1061f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
@@ -22,6 +22,7 @@ package org.apache.sysds.hops.codegen.cplan;
 import java.util.ArrayList;
 
 import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
 import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
 import org.apache.sysds.runtime.codegen.SpoofCellwise;
@@ -274,4 +275,14 @@ public class CNodeCell extends CNodeTpl
                return (api == GeneratorAPI.CUDA || api == GeneratorAPI.JAVA) 
&& _output.isSupported(api) &&
                        !(getSpoofAggOp() == SpoofCellwise.AggOp.SUM_SQ);
        }
+       
+       public int compile(GeneratorAPI api, String src) { // returns the native operator id; -1 when api is not CUDA
+               if(api == GeneratorAPI.CUDA)
+                       return 
compile_nvrtc(SpoofCompiler.native_contexts.get(api), _genVar, src, 
_type.getValue
+                                       (), 
+                               _aggOp != null ? _aggOp.getValue() : 0, 
_sparseSafe); // 0 encodes "no AggOp" (AggOp codes start at 1)
+               return -1;
+       }
+       
+       private native int compile_nvrtc(long context, String name, String src, 
int type, int agg_op, boolean sparseSafe); // bound in the SPOOF CUDA JNI bridge
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
index 895d945..c14c2c7 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
@@ -25,6 +25,7 @@ import java.util.Arrays;
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
@@ -216,4 +217,25 @@ public class CNodeMultiAgg extends CNodeTpl
                }
                return  is_supported;
        }
+       
+       /**
+        * Compiles the generated CUDA source through the native SPOOF context.
+        * @return the native operator id, or -1 if the operator was not compiled
+        */
+       public int compile(GeneratorAPI api, String src) {
+               if(api == GeneratorAPI.CUDA) {
+                       try {
+                               return compile_nvrtc(SpoofCompiler.native_contexts.get(api), _genVar, src, _sparseSafe);
+                       }
+                       catch(UnsatisfiedLinkError e) {
+                               // ToDo: compile MA - the JNI bridge does not appear to bind this method yet;
+                               // returning -1 lets SpoofCompiler fall back to the JAVA operator instead of crashing
+                               return -1;
+                       }
+               }
+               return -1;
+       }
+       
+       private native int compile_nvrtc(long context, String name, String src, boolean sparseSafe);
+       
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
index 6a3a647..955e9e2 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.codegen.cplan;
 
 import java.util.ArrayList;
 
+import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
 import org.apache.sysds.lops.MMTSJ;
 import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
@@ -197,4 +198,26 @@ public class CNodeOuterProduct extends CNodeTpl
                }
                return  is_supported;
        }
+
+       /**
+        * Compiles the generated CUDA source through the native SPOOF context.
+        * @return the native operator id, or -1 if the operator was not compiled
+        */
+       public int compile(GeneratorAPI api, String src) {
+               if(api == GeneratorAPI.CUDA) {
+                       try {
+                               // NOTE(review): ordinal() depends on OutProdType declaration order, unlike the explicit codes Cell/Row pass
+                               return compile_nvrtc(SpoofCompiler.native_contexts.get(api), _genVar, src, _type.ordinal());
+                       }
+                       catch(UnsatisfiedLinkError e) {
+                               // the JNI bridge does not appear to bind this method yet; returning -1
+                               // lets SpoofCompiler fall back to the JAVA operator instead of crashing
+                               return -1;
+                       }
+               }
+               return -1;
+       }
+       
+       private native int compile_nvrtc(long context, String name, String src, int type);
+
 }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
index 74bde95..211c289 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
 import java.util.regex.Matcher;
 import java.util.stream.Collectors;
 
+import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
 import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
 import org.apache.sysds.hops.codegen.template.TemplateUtils;
@@ -52,6 +53,7 @@ public class CNodeRow extends CNodeTpl
        private RowType _type = null; //access pattern 
        private long _constDim2 = -1; //constant number of output columns
        private int _numVectors = -1; //number of intermediate vectors
+       private boolean _tb1 = false;
        
        public void setRowType(RowType type) {
                _type = type;
@@ -115,8 +117,8 @@ public class CNodeRow extends CNodeTpl
                //replace colvector information and number of vector 
intermediates
                tmp = tmp.replace("%TYPE%", _type.name());
                tmp = tmp.replace("%CONST_DIM2%", String.valueOf(_constDim2));
-               tmp = tmp.replace("%TB1%", String.valueOf(
-                       TemplateUtils.containsBinary(_output, 
BinType.VECT_MATRIXMULT)));
+               _tb1 = TemplateUtils.containsBinary(_output, 
BinType.VECT_MATRIXMULT); 
+               tmp = tmp.replace("%TB1%", String.valueOf(_tb1));
 
                if(api == GeneratorAPI.CUDA && _numVectors > 0) {
                        tmp = tmp.replace("//%HAS_TEMP_VECT%", ": public 
TempStorageImpl<T, NUM_TMP_VECT, TMP_VECT_LEN>");
@@ -230,4 +232,15 @@ public class CNodeRow extends CNodeTpl
        public boolean isSupported(GeneratorAPI api) {
                return (api == GeneratorAPI.CUDA || api == GeneratorAPI.JAVA) 
&& _output.isSupported(api);
        }
+
+       public int compile(GeneratorAPI api, String src) { // returns the native operator id; -1 when api is not CUDA
+               if(api == GeneratorAPI.CUDA)
+                       return 
compile_nvrtc(SpoofCompiler.native_contexts.get(api), _genVar, src, 
_type.getValue(), (int) _constDim2, 
+                               _numVectors, _tb1); // _tb1 is set by codegen(), which must run first
+               return -1;
+       }
+       
+       private native int compile_nvrtc(long context, String name, String src, 
int type, int constDim2, int numVectors, 
+                       boolean TB1); // must match the C bridge: (JLjava/lang/String;Ljava/lang/String;IIIZ)I
+
 }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTpl.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTpl.java
index 552b951..4f990c0 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTpl.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTpl.java
@@ -238,4 +238,6 @@ public abstract class CNodeTpl extends CNode implements 
Cloneable
        public GeneratorAPI getGeneratorAPI() { return api; }
        
        public void setGeneratorAPI(GeneratorAPI _api) { api = _api; }
+       
+       public abstract int compile(GeneratorAPI api, String src); // implementations return a native operator id (>= 0) or -1 when not compiled
 }
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
index f3ccf12..88d91a7 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
@@ -37,15 +37,15 @@ public class SpoofCUDA extends SpoofOperator {
        private static final long serialVersionUID = -2161276866245388359L;
        
        private final CNodeTpl cnt;
-       private final Class<?> java_op;
        public final String name;
        public final String src;
+       public final int id;
 
-       public SpoofCUDA(String source, CNodeTpl cnode, Class<?> java_op) {
+       public SpoofCUDA(String source, CNodeTpl cnode, int _id) { // _id: operator id returned by CNodeTpl.compile()
                name = "codegen." + cnode.getVarname();
                cnt = cnode;
                src = source;
-               this.java_op = java_op;
+               id = _id;
        }
 
        public String getName() {
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java
index f266f6d..89b502e 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java
@@ -23,6 +23,7 @@ import java.io.Serializable;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -49,11 +50,33 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl
 {
        private static final long serialVersionUID = 3442528770573293590L;
        
+       // these values need to match with their native counterparts (spoof cuda ops)
        public enum CellType {
-               NO_AGG,
-               FULL_AGG,
-               ROW_AGG,
-               COL_AGG,
+               NO_AGG(1),
+               FULL_AGG(2),
+               ROW_AGG(3),
+               COL_AGG(4);
+               
+               private final int value;
+               private final static HashMap<Integer, CellType> map = new 
HashMap<>();
+               
+               CellType(int value) {
+                       this.value = value;
+               }
+               
+               static {
+                       for (CellType cellType : CellType.values()) {
+                               map.put(cellType.value, cellType);
+                       }
+               }
+               
+               public static CellType valueOf(int cellType) {
+                       return map.get(cellType);
+               }
+               
+               public int getValue() {
+                       return value;
+               }                       
        }
        
        //redefinition of Hop.AggOp for cleaner imports in generate class
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
index b49cb8d..bd37aaf 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.codegen;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -45,18 +46,42 @@ public abstract class SpoofRowwise extends SpoofOperator
 {
        private static final long serialVersionUID = 6242910797139642998L;
        
+       // Enum with explicit integer values
+       // Thanks to https://codingexplained.com/coding/java/enum-to-integer-and-integer-to-enum
+       // these values need to match with their native counterparts (spoof cuda ops)
        public enum RowType {
-               NO_AGG,       //no aggregation
-               NO_AGG_B1,    //no aggregation w/ matrix mult B1
-               NO_AGG_CONST, //no aggregation w/ expansion/contraction
-               FULL_AGG,     //full row/col aggregation
-               ROW_AGG,      //row aggregation (e.g., rowSums() or X %*% v)
-               COL_AGG,      //col aggregation (e.g., colSums() or t(y) %*% X)
-               COL_AGG_T,    //transposed col aggregation (e.g., t(X) %*% y)
-               COL_AGG_B1,   //col aggregation w/ matrix mult B1
-               COL_AGG_B1_T, //transposed col aggregation w/ matrix mult B1
-               COL_AGG_B1R,  //col aggregation w/ matrix mult B1 to row vector
-               COL_AGG_CONST;//col aggregation w/ expansion/contraction
+               NO_AGG(1),       //no aggregation
+               NO_AGG_B1(2),    //no aggregation w/ matrix mult B1
+               NO_AGG_CONST(3), //no aggregation w/ expansion/contraction
+               FULL_AGG(4),     //full row/col aggregation
+               ROW_AGG(5),      //row aggregation (e.g., rowSums() or X %*% v)
+               COL_AGG (6),      //col aggregation (e.g., colSums() or t(y) %*% X)
+               COL_AGG_T(7),    //transposed col aggregation (e.g., t(X) %*% y)
+               COL_AGG_B1(8),   //col aggregation w/ matrix mult B1
+               COL_AGG_B1_T(9), //transposed col aggregation w/ matrix mult B1
+               COL_AGG_B1R(10),  //col aggregation w/ matrix mult B1 to row vector
+               COL_AGG_CONST (11);//col aggregation w/ expansion/contraction
+               
+               private final int value;
+               private final static HashMap<Integer, RowType> map = new 
HashMap<>();
+               
+               RowType(int value) {
+                       this.value = value;
+               }
+               
+               static {
+                       for (RowType rowType : RowType.values()) {
+                               map.put(rowType.value, rowType);
+                       }
+               }
+               
+               public static RowType valueOf(int rowType) {
+                       return map.get(rowType);
+               }
+               
+               public int getValue() {
+                       return value;
+               }
                
                public boolean isColumnAgg() {
                        return this == COL_AGG || this == COL_AGG_T
diff --git a/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java
index 5fac4a9..10ea91e 100644
--- a/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java
@@ -171,15 +171,10 @@ public class CellwiseTmplTest extends AutomatedTestBase
        }
        
        @Test
-       public void testCodegenCellwise4() 
-       {
-               testCodegenIntegration( TEST_NAME4, false, ExecType.CP  );
-       }
+       public void testCodegenCellwise4() { testCodegenIntegration( 
TEST_NAME4, false, ExecType.CP ); }
        
        @Test
-       public void testCodegenCellwise5() {
-               testCodegenIntegration( TEST_NAME5, false, ExecType.CP  );
-       }
+       public void testCodegenCellwise5() { testCodegenIntegration( 
TEST_NAME5, false, ExecType.CP  ); }
        
        @Test
        public void testCodegenCellwise6() {
@@ -192,9 +187,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
        }
        
        @Test
-       public void testCodegenCellwise8() {
-               testCodegenIntegration( TEST_NAME8, false, ExecType.CP  );
-       }
+       public void testCodegenCellwise8() { testCodegenIntegration( 
TEST_NAME8, false, ExecType.CP  ); }
        
        @Test
        public void testCodegenCellwise9() {

Reply via email to