This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0725185  [SYSTEMDS-3024] Improve performance by batching data 
descriptor transfers
0725185 is described below

commit 072518513bc8454457701132b492e86e2928f44a
Author: Mark Dokter <[email protected]>
AuthorDate: Thu Apr 15 15:50:02 2021 +0200

    [SYSTEMDS-3024] Improve performance by batching data descriptor transfers
    
    The spoof cuda operators do several little cudaMemcpy() invocations per 
operator execution. By transferring all data in one go the overhead can be 
reduced. In addition, using asynchronous copies can further improve things and 
are a first step towards using more asynchronicity in the GPU operations.
    
    Closes #1567
---
 src/main/cuda/headers/Matrix.h                     |  42 +----
 src/main/cuda/headers/reduction.cuh                |  52 +-----
 src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp  |  19 +-
 src/main/cuda/spoof-launcher/SpoofCUDAContext.h    |  62 ++-----
 src/main/cuda/spoof-launcher/SpoofCellwise.h       |  94 +++++-----
 src/main/cuda/spoof-launcher/SpoofOperator.h       | 128 +++++++++----
 src/main/cuda/spoof-launcher/SpoofRowwise.h        |  38 ++--
 src/main/cuda/spoof-launcher/jni_bridge.cpp        | 177 ++++++++----------
 src/main/cuda/spoof-launcher/jni_bridge.h          |  57 +++---
 src/main/cuda/spoof/cellwise.cu                    |   3 +-
 src/main/cuda/spoof/rowwise.cu                     |   2 +
 .../apache/sysds/hops/codegen/cplan/CNodeCell.java |  15 +-
 .../sysds/hops/codegen/cplan/cuda/Unary.java       |   4 +-
 .../sysds/runtime/codegen/SpoofCUDACellwise.java   | 116 +++++-------
 .../sysds/runtime/codegen/SpoofCUDAOperator.java   | 205 +++++++++------------
 .../sysds/runtime/codegen/SpoofCUDARowwise.java    |  80 +++-----
 .../sysds/runtime/codegen/SpoofOperator.java       |   3 +
 .../instructions/gpu/SpoofCUDAInstruction.java     |  21 +--
 .../instructions/gpu/context/GPUObject.java        |   3 +-
 19 files changed, 505 insertions(+), 616 deletions(-)

diff --git a/src/main/cuda/headers/Matrix.h b/src/main/cuda/headers/Matrix.h
index 0446983..61ef939 100644
--- a/src/main/cuda/headers/Matrix.h
+++ b/src/main/cuda/headers/Matrix.h
@@ -26,19 +26,19 @@ using int32_t = int;
 
 template <typename T>
 struct Matrix {
-       int32_t nnz;
+       uint64_t nnz;
        uint32_t rows;
        uint32_t cols;
-       
        uint32_t* row_ptr;
        uint32_t* col_idx;
        T* data;
        
        typedef T value_type;
        
-       explicit Matrix(size_t* jvals) : nnz(jvals[0]), rows(jvals[1]), 
cols(jvals[2]),
-                       row_ptr(reinterpret_cast<uint32_t*>(jvals[3])),
-                       col_idx(reinterpret_cast<uint32_t*>((jvals[4]))), 
data(reinterpret_cast<T*>(jvals[5])) {}
+       explicit Matrix(uint8_t* jvals) : 
nnz(*reinterpret_cast<uint32_t*>(&jvals[0])),
+               rows(*reinterpret_cast<uint32_t*>(&jvals[8])), 
cols(*reinterpret_cast<uint32_t*>(&jvals[12])),
+                       row_ptr(reinterpret_cast<uint32_t*>(jvals[16])), 
col_idx(reinterpret_cast<uint32_t*>((jvals[24]))),
+                               data(static_cast<T*>(jvals[32])) {}
 };
 
 #ifdef __CUDACC__
@@ -72,7 +72,7 @@ public:
        
        __device__ void init(Matrix<T>* mat) { _mat = mat; }
        
-       __device__ uint32_t& nnz() { return _mat->nnz; }
+       __device__ uint32_t& nnz() { return return _mat->row_ptr == nullptr ? 
_mat->rows * _mat->cols : _mat->nnz; }
        __device__ uint32_t cols() { return _mat->cols; }
        __device__ uint32_t rows() { return _mat->rows; }
        
@@ -133,7 +133,7 @@ private:
        
        //ToDo sparse accessors
        __device__ uint32_t len_sparse() {
-               return _mat->nnz;
+               return _mat->row_ptr[_mat->rows];
        }
        
        __device__ uint32_t pos_sparse(uint32_t rix) {
@@ -227,34 +227,6 @@ public:
        }
 };
 
-template <typename T, int NUM_B>
-struct SpoofOp {
-       MatrixAccessor<T> a;
-       MatrixAccessor<T> b[NUM_B];
-       MatrixAccessor<T> c;
-       T* scalars;
-       uint32_t grix;
-       T* avals;
-       uint32_t* aix;
-       uint32_t alen;
-       
-       SpoofOp(Matrix<T>* A, Matrix<T>* B, Matrix<T>* C, T* scalars, T* 
tmp_stor, uint32_t grix) :
-                       scalars(scalars), grix(grix), avals(A->data), 
aix(A->col_idx) {
-               a.init(A);
-               c.init(C);
-               alen = a.row_len(grix);
-
-               if(B)
-                       for(auto i = 0; i < NUM_B; ++i)
-                               b[i].init(&(B[i]));
-       }
-       
-//     __device__ Vector<T>& getTempStorage(uint32_t len) {
-//             Vector<T>& vec = temp_rb.next();
-//             tvec.length = len;
-//             return vec;
-//     }
-};
 #endif // __CUDACC_RTC__
 
 #endif //SYSTEMDS_MATRIX_H
diff --git a/src/main/cuda/headers/reduction.cuh 
b/src/main/cuda/headers/reduction.cuh
index 3cd0a0b..7707fe2 100644
--- a/src/main/cuda/headers/reduction.cuh
+++ b/src/main/cuda/headers/reduction.cuh
@@ -22,6 +22,7 @@
 #define REDUCTION_CUH
 
 using uint = unsigned int;
+
 #include <cuda_runtime.h>
 
 #include "utils.cuh"
@@ -51,7 +52,9 @@ using uint = unsigned int;
  * @param SpoofCellwiseOp              initial value for the reduction variable
  */
 template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
-__device__ void FULL_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, 
uint32_t N, T VT, ReductionOp reduction_op, SpoofCellwiseOp spoof_op) {
+__device__ void FULL_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, 
uint32_t N, T VT, ReductionOp reduction_op,
+       SpoofCellwiseOp spoof_op)
+{
        auto sdata = shared_memory_proxy<T>();
 
        // perform first level of reduction,
@@ -66,12 +69,9 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
        // number of active thread blocks (via gridDim).  More blocks will 
result
        // in a larger gridSize and therefore fewer elements per thread
        while (i < N) {
-//             printf("tid=%d i=%d N=%d, in->cols()=%d rix=%d\n", threadIdx.x, 
i, N, in->cols(), i/in->cols());
                v = reduction_op(v, spoof_op(*(in->vals(i)), i, i / in->cols(), 
i % in->cols()));
 
                if (i + blockDim.x < N) {
-                       //__syncthreads();
-                       //printf("loop fetch i(%d)+blockDim.x(%d)=%d, 
in=%f\n",i, blockDim.x, i + blockDim.x, g_idata[i + blockDim.x]);
                        v = reduction_op(v, spoof_op(*(in->vals(i+blockDim.x)), 
blockDim.x + i, (i+blockDim.x) / in->cols(), (i+blockDim.x) % in->cols()));
                }
 
@@ -116,40 +116,25 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
                if (blockDim.x >= 64) {
                        smem[tid] = v = reduction_op(v, smem[tid + 32]);
                }
-//             if(tid<12)
-//                     printf("bid=%d tid=%d reduction result: %3.1f\n", 
blockIdx.x, tid, sdata[tid]);
-               
                if (blockDim.x >= 32) {
                        smem[tid] = v = reduction_op(v, smem[tid + 16]);
                }
-//             if(tid==0)
-//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 16) {
                        smem[tid] = v = reduction_op(v, smem[tid + 8]);
                }
-//             if(tid==0)
-//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 8) {
                        smem[tid] = v = reduction_op(v, smem[tid + 4]);
                }
-//             if(tid==0)
-//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 4) {
                        smem[tid] = v = reduction_op(v, smem[tid + 2]);
                }
-//             if(tid==0)
-//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                if (blockDim.x >= 2) {
                        smem[tid] = v = reduction_op(v, smem[tid + 1]);
                }
-//             if(tid==0)
-//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
        }
 
         // write result for this block to global mem
         if (tid == 0) {
-//             if(gridDim.x < 10)
-//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
                out->val(0, blockIdx.x) = sdata[0];
         }
 }
@@ -174,19 +159,10 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
  * the value before writing it to its final location in global memory for each
  * row
  */
-//template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
-//__device__ void ROW_AGG(
-//             T *g_idata, ///< input data stored in device memory (of size 
rows*cols)
-//             T *g_odata,  ///< output/temporary array store in device memory 
(of size
-//             /// rows*cols)
-//             uint rows,  ///< rows in input and temporary/output arrays
-//             uint cols,  ///< columns in input and temporary/output arrays
-//             T initialValue,  ///< initial value for the reduction variable
-//             ReductionOp reduction_op, ///< Reduction operation to perform 
(functor object)
-//             SpoofCellwiseOp spoof_op) ///< Operation to perform before 
assigning this
 template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
 __device__ void ROW_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, 
uint32_t N, T VT,  ReductionOp reduction_op,
-                                          SpoofCellwiseOp spoof_op) {
+       SpoofCellwiseOp spoof_op)
+{
        auto sdata = shared_memory_proxy<T>();
 
        // one block per row
@@ -199,7 +175,6 @@ __device__ void ROW_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
        uint32_t i = tid;
        uint block_offset = block * in->cols();
 
-//     T v = initialValue;
        T v = reduction_op.init();
        while (i < in->cols()) {
                v = reduction_op(v, spoof_op(in->val(block_offset + i), i, i / 
in->cols(), i % in->cols()));
@@ -283,16 +258,8 @@ __device__ void ROW_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t
  */
 template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
 __device__ void COL_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, 
uint32_t N, T VT,  ReductionOp reduction_op,
-                                               SpoofCellwiseOp spoof_op) {
-//__device__ void COL_AGG(T *g_idata, ///< input data stored in device memory 
(of size rows*cols)
-//             T *g_odata,  ///< output/temporary array store in device memory 
(of size rows*cols)
-//             uint rows,  ///< rows in input and temporary/output arrays
-//             uint cols,  ///< columns in input and temporary/output arrays
-//             T initialValue,  ///< initial value for the reduction variable
-//             ReductionOp reduction_op, ///< Reduction operation to perform 
(functor object)
-//             SpoofCellwiseOp spoof_op) ///< Operation to perform before 
aggregation
-//
-//{
+       SpoofCellwiseOp spoof_op)
+{
        uint global_tid = blockIdx.x * blockDim.x + threadIdx.x;
        if (global_tid >= in->cols()) {
                return;
@@ -315,13 +282,12 @@ __device__ void NO_AGG(MatrixAccessor<T>* in, 
MatrixAccessor<T>* out, uint32_t N
        uint32_t gtid = blockIdx.x * blockDim.x + threadIdx.x;
        uint32_t first_idx = gtid * static_cast<uint32_t>(VT);
        uint32_t last_idx = min(first_idx + static_cast<uint32_t>(VT), N);
+
        #pragma unroll
        for(auto i = first_idx; i < last_idx; i++) {
                T a = in->hasData() ? in->vals(0)[i] : 0;
                T result = spoof_op(a, i, i / in->cols(), i % in->cols());
                out->vals(0)[i] = result;
-               //if(i < 4)
-               //      printf("tid=%d in=%4.3f res=%4.3f out=%4.3f r=%d\n", i, 
in->vals(0)[i], result, out->vals(0)[i], i/in->cols());
        }
 }
 
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp 
b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
index 233a4a1..2ef482d 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
@@ -77,10 +77,15 @@ size_t SpoofCUDAContext::initialize_cuda(uint32_t 
device_id, const char* resourc
        CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_d"));
        
ctx->reduction_kernels_d.insert(std::make_pair(std::make_pair(SpoofOperator::AggType::FULL_AGG,
 SpoofOperator::AggOp::MAX), func));
        
+       
CHECK_CUDART(cudaMallocHost(reinterpret_cast<void**>(&(ctx->staging_buffer)), 
ctx->default_mem_size));
+       
CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&(ctx->device_buffer)), 
ctx->default_mem_size));
+       ctx->current_mem_size = ctx->default_mem_size;
        return reinterpret_cast<size_t>(ctx);
 }
 
 void SpoofCUDAContext::destroy_cuda(SpoofCUDAContext *ctx, [[maybe_unused]] 
uint32_t device_id) {
+       cudaFreeHost(ctx->staging_buffer);
+       cudaFree(ctx->device_buffer);
        delete ctx;
        // cuda device is handled by jCuda atm
        //cudaDeviceReset();
@@ -116,15 +121,25 @@ size_t 
SpoofCUDAContext::compile(std::unique_ptr<SpoofOperator> op, const std::s
 
 template<typename T>
 CUfunction SpoofCUDAContext::getReductionKernel(const 
std::pair<SpoofOperator::AggType, SpoofOperator::AggOp> &key) {
-       return nullptr;
+       return nullptr; // generic case never used
 }
+
 template<>
 CUfunction SpoofCUDAContext::getReductionKernel<float>(const 
std::pair<SpoofOperator::AggType,
                SpoofOperator::AggOp> &key) {
        return reduction_kernels_f[key];
 }
+
 template<>
 CUfunction SpoofCUDAContext::getReductionKernel<double>(const 
std::pair<SpoofOperator::AggType,
                SpoofOperator::AggOp> &key) {
        return reduction_kernels_d[key];
-}
\ No newline at end of file
+}
+
+void SpoofCUDAContext::resize_staging_buffer(size_t size) {
+       cudaFreeHost(staging_buffer);
+       cudaFree(device_buffer);
+       
CHECK_CUDART(cudaMallocHost(reinterpret_cast<void**>(&(staging_buffer)), size));
+       CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&(device_buffer)), 
size));
+       current_mem_size = size;
+}
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h 
b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
index 696682f..e4b80c5 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
@@ -46,62 +46,34 @@ class SpoofCUDAContext {
        std::map<std::pair<SpoofOperator::AggType, SpoofOperator::AggOp>, 
CUfunction> reduction_kernels_f;
        std::map<std::pair<SpoofOperator::AggType, SpoofOperator::AggOp>, 
CUfunction> reduction_kernels_d;
 
-//     double handling_total, compile_total;
-       
        const std::string resource_path;
        const std::vector<std::string> include_paths;
        
 public:
+       size_t default_mem_size = 1024; // 1kb for hosting data pointers, 
scalars and some meta info. This default should
+                                                                       // not 
require resizing these buffers in most cases.
+       size_t current_mem_size = 0; // the actual staging buffer size (should 
be default unless there was a resize)
+       std::byte* staging_buffer{}; // pinned host mem for async transfers
+       std::byte* device_buffer{};  // this buffer holds the pointers to the 
data buffers
 
        explicit SpoofCUDAContext(const char* resource_path_, 
std::vector<std::string>  include_paths_) : reductions(nullptr),
-                       resource_path(resource_path_), 
include_paths(std::move(include_paths_))
-                       //,handling_total(0.0), compile_total(0.0)
-                       {}
+                       resource_path(resource_path_), 
include_paths(std::move(include_paths_)) { }
 
        static size_t initialize_cuda(uint32_t device_id, const char* 
resource_path_);
 
        static void destroy_cuda(SpoofCUDAContext *ctx, uint32_t device_id);
-       
+
        size_t compile(std::unique_ptr<SpoofOperator> op, const std::string 
&src);
-       
+
        template <typename T, typename CALL>
-       int launch(uint32_t opID, std::vector<Matrix<T>>& input, 
std::vector<Matrix<T>>& sides, Matrix<T>& output,
-                       T* scalars, uint32_t grix) {
-               // dp holds in/side/out/scalar pointers for GPU
-               DevMatPtrs<T> dp;
-
-               SpoofOperator* op = compiled_ops[opID].get();
-               
-               CHECK_CUDART(cudaMalloc((void **)&dp.in, sizeof(Matrix<T>) * 
input.size()));
-               CHECK_CUDART(cudaMemcpy(dp.in, 
reinterpret_cast<void*>(&input[0]), sizeof(Matrix<T>) * input.size(),
-                               cudaMemcpyHostToDevice));
-
-               if (!sides.empty()) {
-                       CHECK_CUDART(cudaMalloc(reinterpret_cast<void 
**>(&dp.sides), sizeof(Matrix<T>) * sides.size()));
-                       CHECK_CUDART(cudaMemcpy(dp.sides, &sides[0], 
sizeof(Matrix<T>)  * sides.size(), cudaMemcpyHostToDevice));
-               }
-               
-               if (op->isSparseSafe() && input.front().row_ptr != nullptr) {
-                       CHECK_CUDART(cudaMemcpy(output.row_ptr, 
input.front().row_ptr, (input.front().rows+1)*sizeof(uint32_t),
-                                       cudaMemcpyDeviceToDevice));
-               }
-#ifndef NDEBUG
-               std::cout << "output rows: " << output.rows << " cols: " << 
output.cols << " nnz: " << output.nnz << " format: " <<
-                               (output.row_ptr == nullptr ? "dense" : 
"sparse") << std::endl;
-#endif
-               size_t out_num_elements = output.rows * output.cols;
-               if(output.row_ptr)
-                       if(op->isSparseSafe() && output.nnz > 0)
-                               out_num_elements = output.nnz;
-               CHECK_CUDART(cudaMalloc((void **) &dp.out, sizeof(Matrix<T>)));
-               CHECK_CUDART(cudaMemset(output.data, 0, out_num_elements * 
sizeof(T)));
-               CHECK_CUDART(cudaMemcpy(dp.out, reinterpret_cast<void 
*>(&output), sizeof(Matrix<T>),
-                               cudaMemcpyHostToDevice));
-               
-               dp.scalars = scalars;
-
-               CALL::exec(this, op, input, sides, output, grix, dp);
-               
+       int launch() {
+
+               DataBufferWrapper dbw(staging_buffer, device_buffer);
+               SpoofOperator* op = compiled_ops[dbw.op_id()].get();
+               dbw.toDevice(op->stream);
+
+               CALL::exec(this, op, &dbw);
+
                return 0;
        }
        
@@ -109,6 +81,8 @@ public:
 
        template<typename T>
        CUfunction getReductionKernel(const std::pair<SpoofOperator::AggType, 
SpoofOperator::AggOp>& key);
+
+       void resize_staging_buffer(size_t size);
 };
 
 #endif // SPOOFCUDACONTEXT_H
diff --git a/src/main/cuda/spoof-launcher/SpoofCellwise.h 
b/src/main/cuda/spoof-launcher/SpoofCellwise.h
index fe7e9d6..9077840 100644
--- a/src/main/cuda/spoof-launcher/SpoofCellwise.h
+++ b/src/main/cuda/spoof-launcher/SpoofCellwise.h
@@ -27,8 +27,7 @@
 template<typename T>
 struct SpoofCellwiseFullAgg {
        
-       static void exec(SpoofCellwiseOp* op, uint32_t NT, uint32_t N, const 
std::string& op_name,
-                       std::vector<Matrix<T>>& sides, uint32_t grix,  
DevMatPtrs<T>& dp) {
+       static void exec(SpoofCellwiseOp* op, uint32_t NT, uint32_t N, const 
std::string& op_name, DataBufferWrapper* dbw) {
                T value_type;
                
                // num ctas
@@ -46,14 +45,15 @@ struct SpoofCellwiseFullAgg {
                                                  << std::endl;
 #endif
                CHECK_CUDA(op->program.get()->kernel(op_name)
-                                                  
.instantiate(type_of(value_type), std::max(static_cast<size_t>(1), 
sides.size()))
-                                                  .configure(grid, block, 
shared_mem_size)
-                                                  .launch(dp.in, dp.sides, 
dp.out, dp.scalars, N, grix));
+                                                  
.instantiate(type_of(value_type), std::max(static_cast<uint32_t>(1u), 
dbw->num_sides()))
+                                                  .configure(grid, block, 
shared_mem_size, op->stream)
+                                                  .launch(dbw->d_in<T>(0), 
dbw->d_sides<T>(), dbw->d_out<T>(), dbw->d_scalars<T>(), N, dbw->grix()));
                
                if(NB > 1) {
                        N = NB;
                        while (NB > 1) {
-                               void* args[3] = { &dp.out, &dp.out, &N};
+                               Matrix<T>* out = dbw->d_out<T>();
+                               void* args[3] = { &out, &out, &N};
                                
                                NB = std::ceil((N + NT * 2 - 1) / (NT * 2));
 #ifndef NDEBUG
@@ -64,7 +64,7 @@ struct SpoofCellwiseFullAgg {
                     << N << " elements"
                     << std::endl;
 #endif
-                               CHECK_CUDA(cuLaunchKernel(op->agg_kernel,NB, 1, 
1, NT, 1, 1, shared_mem_size, nullptr, args, nullptr));
+                               CHECK_CUDA(cuLaunchKernel(op->agg_kernel,NB, 1, 
1, NT, 1, 1, shared_mem_size, op->stream, args, nullptr));
                                N = NB;
                        }
                }
@@ -74,12 +74,11 @@ struct SpoofCellwiseFullAgg {
 
 template<typename T>
 struct SpoofCellwiseRowAgg {
-       static void exec(SpoofOperator *op, uint32_t NT, uint32_t N, const 
std::string &op_name,
-                         std::vector<Matrix<T>> &input, std::vector<Matrix<T>> 
&sides, uint32_t grix, DevMatPtrs<T>& dp) {
+       static void exec(SpoofOperator *op, uint32_t NT, uint32_t N, const 
std::string &op_name, DataBufferWrapper* dbw) {
                T value_type;
                
                // num ctas
-               uint32_t NB = input.front().rows;
+               uint32_t NB = dbw->h_in<T>(0)->rows;
                dim3 grid(NB, 1, 1);
                dim3 block(NT, 1, 1);
                uint32_t shared_mem_size = NT * sizeof(T);
@@ -90,9 +89,9 @@ struct SpoofCellwiseRowAgg {
                                        << N << " elements" << std::endl;
 #endif
                CHECK_CUDA(op->program->kernel(op_name)
-                                                  
.instantiate(type_of(value_type), std::max(static_cast<size_t>(1), 
sides.size()))
-                                                  .configure(grid, block, 
shared_mem_size)
-                                                  .launch(dp.in, dp.sides, 
dp.out, dp.scalars, N, grix));
+                                                  
.instantiate(type_of(value_type), std::max(static_cast<uint32_t>(1u), 
dbw->num_sides()))
+                                                  .configure(grid, block, 
shared_mem_size, op->stream)
+                                                  .launch(dbw->d_in<T>(0), 
dbw->d_sides<T>(), dbw->d_out<T>(), dbw->d_scalars<T>(), N, dbw->grix()));
                
        }
 };
@@ -100,8 +99,7 @@ struct SpoofCellwiseRowAgg {
 
 template<typename T>
 struct SpoofCellwiseColAgg {
-       static void exec(SpoofOperator* op, uint32_t NT, uint32_t N, const 
std::string& op_name,
-                                        std::vector<Matrix<T>>& sides, 
uint32_t grix, DevMatPtrs<T>& dp) {
+       static void exec(SpoofOperator* op, uint32_t NT, uint32_t N, const 
std::string& op_name, DataBufferWrapper* dbw) {
                T value_type;
                
                // num ctas
@@ -116,9 +114,9 @@ struct SpoofCellwiseColAgg {
                                                << N << " elements" << 
std::endl;
 #endif
                CHECK_CUDA(op->program->kernel(op_name)
-                                                  
.instantiate(type_of(value_type), std::max(static_cast<size_t>(1), 
sides.size()))
-                                                  .configure(grid, block, 
shared_mem_size)
-                                                  .launch(dp.in, dp.sides, 
dp.out, dp.scalars, N, grix));
+                                                  
.instantiate(type_of(value_type), std::max(static_cast<uint32_t>(1u), 
dbw->num_sides()))
+                                                  .configure(grid, block, 
shared_mem_size, op->stream)
+                                                  .launch(dbw->d_in<T>(0), 
dbw->d_sides<T>(), dbw->d_out<T>(), dbw->d_scalars<T>(), N, dbw->grix()));
                
        }
 };
@@ -126,72 +124,78 @@ struct SpoofCellwiseColAgg {
 
 template<typename T>
 struct SpoofCellwiseNoAgg {
-       static void exec(SpoofOperator *op, uint32_t NT, uint32_t N, const 
std::string &op_name,
-                         std::vector<Matrix<T>> &input, std::vector<Matrix<T>> 
&sides, uint32_t grix, DevMatPtrs<T>& dp) {
+       static void exec(SpoofOperator *op, uint32_t NT, uint32_t N, const 
std::string &op_name, DataBufferWrapper* dbw) {
                T value_type;
-               bool sparse_input = input.front().row_ptr != nullptr;
+               bool sparse_input = dbw->h_in<T>(0)->row_ptr != nullptr;
                
                // num ctas
-               // ToDo: adaptive VT
+               // ToDo? adaptive VT
                const uint32_t VT = 1;
                uint32_t NB = std::ceil((N + NT * VT - 1) / (NT * VT));
                if(sparse_input)
-                       NB = input.front().rows;
+                       NB = dbw->h_in<T>(0)->rows;
                dim3 grid(NB, 1, 1);
                dim3 block(NT, 1, 1);
                uint32_t shared_mem_size = 0;
 
 #ifndef NDEBUG
+               std::cout << "output rows: " << dbw->h_out<T>()->rows << " 
cols: " << dbw->h_out<T>()->cols << " nnz: " <<
+                       (dbw->h_out<T>()->row_ptr == nullptr ? 
dbw->h_out<T>()->rows * dbw->h_out<T>()->cols :
+                               dbw->h_out<T>()->nnz) << " format: " << 
(dbw->h_out<T>()->row_ptr == nullptr
+                                       ? "dense" : "sparse") << std::endl;
+
                if(sparse_input) {
                                std::cout << "launching sparse spoof cellwise 
kernel " << op_name << " with " << NT * NB
-                                                 << " threads in " << NB << " 
blocks without aggregation for " << N << " elements"
-                                                 << std::endl;
+                                       << " threads in " << NB << " blocks 
without aggregation for " << N << " elements" << std::endl;
                }
                else {
-                       std::cout << "launching spoof cellwise kernel " << 
op_name << " with " << NT * NB
-                                         << " threads in " << NB << " blocks 
without aggregation for " << N << " elements"
-                                         << std::endl;
+                       std::cout << "launching spoof cellwise kernel " << 
op_name << " with " << NT * NB << " threads in " << NB <<
+                               " blocks without aggregation for " << N << " 
elements" << std::endl;
                }
 #endif
-               
                CHECK_CUDA(op->program->kernel(op_name)
-                                                  
.instantiate(type_of(value_type), std::max(static_cast<size_t>(1), 
sides.size()))
-                                                  .configure(grid, block, 
shared_mem_size)
-                                                  .launch(dp.in, dp.sides, 
dp.out, dp.scalars, N, grix));
+                                                  
.instantiate(type_of(value_type), std::max(static_cast<uint32_t>(1u), 
dbw->num_sides()))
+                                                  .configure(grid, block, 
shared_mem_size, op->stream)
+                                                  .launch(dbw->d_in<T>(0), 
dbw->d_sides<T>(), dbw->d_out<T>(), dbw->d_scalars<T>(), N, dbw->grix()));
+
+               // copy over row indices from input to output if appropriate
+               if (op->isSparseSafe() && dbw->h_in<T>(0)->row_ptr != nullptr) {
+                       // src/dst information (pointer address) is stored in 
*host* buffer!
+                       CHECK_CUDART(cudaMemcpyAsync(dbw->h_out<T>()->row_ptr, 
dbw->h_in<T>(0)->row_ptr,
+                               (dbw->h_in<T>(0)->rows+1) * sizeof(uint32_t), 
cudaMemcpyDeviceToDevice, op->stream));
+                       CHECK_CUDART(cudaMemcpyAsync(dbw->h_out<T>()->col_idx, 
dbw->h_in<T>(0)->col_idx,
+                                                                               
 (dbw->h_in<T>(0)->nnz) * sizeof(uint32_t), cudaMemcpyDeviceToDevice, 
op->stream));
+               }
        }
 };
 
 template<typename T>
 struct SpoofCellwise {
-       static void exec(SpoofCUDAContext* ctx, SpoofOperator* _op, 
std::vector<Matrix<T>>& input,
-                       std::vector<Matrix<T>>& sides, Matrix<T>& output, 
uint32_t grix,
-                       DevMatPtrs<T>& dp)  {
-               
-               T value_type;
+       static void exec(SpoofCUDAContext* ctx, SpoofOperator* _op, 
DataBufferWrapper* dbw) {
                auto* op = dynamic_cast<SpoofCellwiseOp*>(_op);
-               bool sparse_input = input.front().row_ptr != nullptr;
+               bool sparse_input = dbw->h_in<T>(0)->row_ptr != nullptr;
                uint32_t NT = 256; // ToDo: num threads
-               uint32_t N = input.front().rows * input.front().cols;
+               uint32_t N = dbw->h_in<T>(0)->rows * dbw->h_in<T>(0)->cols;
                std::string op_name(op->name + "_DENSE");
                if(sparse_input) {
                        op_name = std::string(op->name + "_SPARSE");
-                       if(op->isSparseSafe() && input.front().nnz > 0)
-                               N = input.front().nnz;
+                       if(op->isSparseSafe() && dbw->h_in<T>(0)->nnz > 0)
+                               N = dbw->h_in<T>(0)->nnz;
                }
                
                switch(op->agg_type) {
                        case SpoofOperator::AggType::FULL_AGG:
                                op->agg_kernel = ctx->template 
getReductionKernel<T>(std::make_pair(op->agg_type, op->agg_op));
-                               SpoofCellwiseFullAgg<T>::exec(op, NT, N, 
op_name, sides, grix, dp);
+                               SpoofCellwiseFullAgg<T>::exec(op, NT, N, 
op_name, dbw);
                                break;
                        case SpoofOperator::AggType::ROW_AGG:
-                               SpoofCellwiseRowAgg<T>::exec(op, NT, N, 
op_name, input, sides, grix, dp);
+                               SpoofCellwiseRowAgg<T>::exec(op, NT, N, 
op_name, dbw);
                                break;
                        case SpoofOperator::AggType::COL_AGG:
-                               SpoofCellwiseColAgg<T>::exec(op, NT, N, 
op_name, sides, grix, dp);
+                               SpoofCellwiseColAgg<T>::exec(op, NT, N, 
op_name, dbw);
                                break;
                        case SpoofOperator::AggType::NO_AGG:
-                               SpoofCellwiseNoAgg<T>::exec(op, NT, N, op_name, 
input, sides, grix, dp);
+                               SpoofCellwiseNoAgg<T>::exec(op, NT, N, op_name, 
dbw);
                                break;
                        default:
                                throw std::runtime_error("unknown cellwise agg 
type" + std::to_string(static_cast<int>(op->agg_type)));
diff --git a/src/main/cuda/spoof-launcher/SpoofOperator.h 
b/src/main/cuda/spoof-launcher/SpoofOperator.h
index 0ccc633..f256e81 100644
--- a/src/main/cuda/spoof-launcher/SpoofOperator.h
+++ b/src/main/cuda/spoof-launcher/SpoofOperator.h
@@ -28,18 +28,24 @@
 #include "host_utils.h"
 #include "Matrix.h"
 
+// these two constants have equivalents in Java code:
+const uint32_t JNI_MAT_ENTRY_SIZE = 40;
+const uint32_t TRANSFERRED_DATA_HEADER_SIZE = 32;
+
 struct SpoofOperator {
-//     enum class OpType : int { CW, RA, MA, OP, NONE };
        enum class AggType : int { NO_AGG, FULL_AGG, ROW_AGG, COL_AGG };
        enum class AggOp : int { SUM, SUM_SQ, MIN, MAX };
        enum class RowType : int { FULL_AGG = 4 };
-       
-//     OpType op_type;
+
        std::string name;
-//     jitify::Program program;
        std::unique_ptr<jitify::Program> program;
        
        [[nodiscard]] virtual bool isSparseSafe() const = 0;
+
+       cudaStream_t stream{};
+       
+       SpoofOperator() { CHECK_CUDART(cudaStreamCreate(&stream));}
+       virtual ~SpoofOperator() {CHECK_CUDART(cudaStreamDestroy(stream));}
 };
 
 struct SpoofCellwiseOp : public SpoofOperator {
@@ -58,46 +64,90 @@ struct SpoofRowwiseOp : public SpoofOperator {
        int32_t const_dim2;
        RowType row_type;
        
-       SpoofRowwiseOp(RowType rt, bool tb1, uint32_t ntv, int32_t cd2)  : 
row_type(rt), TB1(tb1), num_temp_vectors(ntv),
+       SpoofRowwiseOp(RowType rt, bool tb1, uint32_t ntv, int32_t cd2) : 
row_type(rt), TB1(tb1), num_temp_vectors(ntv),
                        const_dim2(cd2) {}
                        
        [[nodiscard]] bool isSparseSafe() const override { return false; }
 };
 
-template<typename T>
-struct DevMatPtrs {
-       Matrix<T>* ptrs[3] = {0,0,0};
-       
-       Matrix<T>*& in = ptrs[0];
-       Matrix<T>*& sides = ptrs[1];
-       Matrix<T>*& out = ptrs[2];
-       T* scalars{};
-
-       ~DevMatPtrs() {
-#ifndef NDEBUG
-               std::cout << "~DevMatPtrs() before cudaFree:\n";
-               int i = 0;
-               for (auto& p : ptrs) {
-                       std::cout << " p[" << i << "]=" << p;
-                       i++;
-               }
-               std::cout << std::endl;
-#endif
-               for (auto& p : ptrs) {
-                       if (p) {
-                               CHECK_CUDART(cudaFree(p));
-                               p = nullptr;
-                       }
-               }
-#ifndef NDEBUG
-               std::cout << "~DevMatPtrs() after cudaFree:\n";
-               i = 0;
-               for (auto& p : ptrs) {
-                       std::cout << " p[" << i << "]=" << p;
-                       i++;
-               }
-               std::cout << std::endl;
-#endif
+struct DataBufferWrapper {
+       std::byte* staging_buffer;
+       std::byte* device_buffer;
+
+
+       template<typename T>
+       Matrix<T>* in(std::byte* buffer, uint32_t idx) {
+               return 
reinterpret_cast<Matrix<T>*>(&buffer[TRANSFERRED_DATA_HEADER_SIZE + idx * 
JNI_MAT_ENTRY_SIZE]);
+       }
+
+       template<typename T>
+       Matrix<T>* sides(std::byte* buffer) {
+               return 
reinterpret_cast<Matrix<T>*>(&buffer[TRANSFERRED_DATA_HEADER_SIZE + 
num_inputs() * JNI_MAT_ENTRY_SIZE]);
+       }
+
+       template<typename T>
+       Matrix<T>* out(std::byte* buffer) {
+               return 
reinterpret_cast<Matrix<T>*>(&buffer[TRANSFERRED_DATA_HEADER_SIZE + 
(num_inputs() + num_sides())
+                       * JNI_MAT_ENTRY_SIZE]);
+       }
+
+       template<typename T>
+       T* scalars(std::byte* buffer, uint32_t idx) {
+               return 
reinterpret_cast<T*>(&(buffer[TRANSFERRED_DATA_HEADER_SIZE + (num_inputs() + 
num_sides() + 1)
+                       * JNI_MAT_ENTRY_SIZE + idx * sizeof(T)]));
+       }
+
+public:
+       explicit DataBufferWrapper(std::byte* staging, std::byte* dev_buf) : 
staging_buffer(staging),
+               device_buffer(dev_buf) { }
+
+       void toDevice(cudaStream_t &stream) const {
+               CHECK_CUDART(cudaMemcpyAsync(device_buffer, staging_buffer, 
*reinterpret_cast<uint32_t*>(&staging_buffer[0]),
+                       cudaMemcpyHostToDevice, stream));
+       }
+
+       template<typename T>
+       Matrix<T>* d_in(uint32_t num) { return in<T>(device_buffer, num); }
+
+       template<typename T>
+       Matrix<T>* h_in(uint32_t num) { return in<T>(staging_buffer, num); }
+
+       template<typename T>
+       Matrix<T>* d_sides() { return sides<T>(device_buffer); }
+
+       template<typename T>
+       Matrix<T>* h_sides() { return sides<T>(staging_buffer); }
+
+       template<typename T>
+       Matrix<T>* d_out() { return out<T>(device_buffer); }
+
+       template<typename T>
+       Matrix<T>* h_out() { return out<T>(staging_buffer); }
+
+       template<typename T>
+       T* d_scalars(uint32_t idx = 0) { return scalars<T>(device_buffer, idx); 
}
+
+       template<typename T>
+       T* h_scalars(uint32_t idx = 0) { return scalars<T>(staging_buffer, 
idx); }
+
+       uint32_t op_id() const {
+               return 
*reinterpret_cast<uint32_t*>(&staging_buffer[sizeof(int)]);
+       }
+
+       uint64_t grix() const {
+               return *reinterpret_cast<uint64_t*>(&staging_buffer[2 * 
sizeof(int)]);
+       }
+
+       [[nodiscard]] uint32_t num_inputs() const {
+               return *reinterpret_cast<uint32_t*>(&staging_buffer[3 * 
sizeof(int)]);
+       }
+
+       uint32_t num_sides() const {
+               return *reinterpret_cast<uint32_t*>(&staging_buffer[4 * 
sizeof(int)]);
+       }
+
+       uint32_t num_scalars() const {
+               return *reinterpret_cast<uint32_t*>(&staging_buffer[6 * 
sizeof(int)]);
        }
 };
 
diff --git a/src/main/cuda/spoof-launcher/SpoofRowwise.h 
b/src/main/cuda/spoof-launcher/SpoofRowwise.h
index 1295314..4465ac9 100644
--- a/src/main/cuda/spoof-launcher/SpoofRowwise.h
+++ b/src/main/cuda/spoof-launcher/SpoofRowwise.h
@@ -27,27 +27,34 @@
 template <typename T>
 struct SpoofRowwise {
        
-       static void exec([[maybe_unused]] SpoofCUDAContext* ctx, SpoofOperator* 
_op, std::vector<Matrix<T>>& input,
-                       std::vector<Matrix<T>>& sides, Matrix<T>& output, 
uint32_t grix, DevMatPtrs<T>& dp)  {
+       static void exec([[maybe_unused]] SpoofCUDAContext* ctx, SpoofOperator* 
_op, DataBufferWrapper* dbw)  {
                uint32_t NT=256;
                T value_type;
-               bool sparse_input = input.front().row_ptr != nullptr;
+               bool sparse_input = dbw->h_in<T>(0)->row_ptr != nullptr;
                auto* op = dynamic_cast<SpoofRowwiseOp*>(_op);
-               dim3 grid(input.front().rows, 1, 1);
+               dim3 grid(dbw->h_in<T>(0)->rows, 1, 1);
                dim3 block(NT, 1, 1);
                unsigned int shared_mem_size = NT * sizeof(T);
-               
+
+               size_t out_num_elements = dbw->h_out<T>()->rows * 
dbw->h_out<T>()->cols;
+               if(dbw->h_out<T>()->row_ptr)
+                       if(op->isSparseSafe() && dbw->h_out<T>()->nnz > 0)
+                               out_num_elements = dbw->h_out<T>()->nnz;
+               //ToDo: only memset output when there is an output operation 
that *adds* to the buffer
+               CHECK_CUDART(cudaMemsetAsync(dbw->h_out<T>()->data, 0, 
out_num_elements * sizeof(T), op->stream));
+
+               //ToDo: handle this in JVM
                uint32_t tmp_len = 0;
                uint32_t temp_buf_size = 0;
                T* d_temp = nullptr;
                if(op->num_temp_vectors > 0) {
-                       tmp_len = std::max(input.front().cols, op->const_dim2 < 
0 ? 0 : static_cast<uint32_t>(op->const_dim2));
-                       temp_buf_size = op->num_temp_vectors * tmp_len * 
input.front().rows * sizeof(T);
+                       tmp_len = std::max(dbw->h_in<T>(0)->cols, 
op->const_dim2 < 0 ? 0u : static_cast<uint32_t>(op->const_dim2));
+                       temp_buf_size = op->num_temp_vectors * tmp_len * 
dbw->h_in<T>(0)->rows * sizeof(T);
 #ifndef NDEBUG
                        std::cout << "num_temp_vect: " << op->num_temp_vectors 
<< " temp_buf_size: " << temp_buf_size << " tmp_len: " << tmp_len << std::endl;
 #endif
                        
CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&d_temp), temp_buf_size));
-                       CHECK_CUDART(cudaMemset(d_temp, 0, temp_buf_size));
+                       CHECK_CUDART(cudaMemsetAsync(d_temp, 0, temp_buf_size, 
op->stream));
                }
                
                std::string op_name(op->name + "_DENSE");
@@ -56,21 +63,18 @@ struct SpoofRowwise {
 
 #ifndef NDEBUG
                // ToDo: connect output to SystemDS logging facilities
-               std::cout << "launching spoof rowwise kernel " << op_name << " 
with " << NT * input.front().rows << " threads in "
-                               << input.front().rows << " blocks and " << 
shared_mem_size << " bytes of shared memory for "
-                               << input.front().rows << " cols processed by " 
<< NT << " threads per row, adding "
+               std::cout << "launching spoof rowwise kernel " << op_name << " 
with " << NT * dbw->h_in<T>(0)->rows << " threads in "
+                               << dbw->h_in<T>(0)->rows << " blocks and " << 
shared_mem_size << " bytes of shared memory for "
+                               << dbw->h_in<T>(0)->rows << " cols processed by 
" << NT << " threads per row, adding "
                                << temp_buf_size / 1024 << " kb of temp buffer 
in global memory." <<  std::endl;
 #endif
                CHECK_CUDA(op->program->kernel(op_name)
-                                                  
.instantiate(type_of(value_type), std::max(static_cast<size_t>(1), 
sides.size()), op->num_temp_vectors, tmp_len)
-                                                  .configure(grid, block, 
shared_mem_size)
-                                                  .launch(dp.in, dp.sides, 
dp.out, dp.scalars, d_temp, grix));
+                                                  
.instantiate(type_of(value_type), std::max(static_cast<uint32_t>(1), 
dbw->num_sides()), op->num_temp_vectors, tmp_len)
+                                                  .configure(grid, block, 
shared_mem_size, op->stream)
+                                                  .launch(dbw->d_in<T>(0), 
dbw->d_sides<T>(), dbw->d_out<T>(), dbw->d_scalars<T>(), d_temp, dbw->grix()));
                
                if(op->num_temp_vectors > 0)
                        CHECK_CUDART(cudaFree(d_temp));
-               
-//             if (op->TB1)
-//                     CHECK_CUDART(cudaFree(b1_transposed));
        }
 };
 
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.cpp 
b/src/main/cuda/spoof-launcher/jni_bridge.cpp
index 9a808a6..5134d5e 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.cpp
+++ b/src/main/cuda/spoof-launcher/jni_bridge.cpp
@@ -23,69 +23,86 @@
 #include "SpoofRowwise.h"
 
 // JNI Methods to get/release arrays
-#define GET_ARRAY(env, input)((void *)env->GetPrimitiveArrayCritical(input, 
nullptr))
+//#define GET_ARRAY(env, input)((void *)env->GetPrimitiveArrayCritical(input, 
nullptr))
+//#define RELEASE_ARRAY(env, java, 
cpp)(env->ReleasePrimitiveArrayCritical(java, cpp, 0))
 
-#define RELEASE_ARRAY(env, java, cpp)(env->ReleasePrimitiveArrayCritical(java, 
cpp, 0))
+jclass jcuda_pointer_class;
+jclass jcuda_native_pointer_class;
+jfieldID pointer_buffer_field;
+jfieldID native_pointer_field;
 
 // error output helper
-void printException(const std::string& name, const std::exception& e, bool 
compile = false) {
+void printException(const std::string &name, const std::exception &e, bool 
compile = false) {
        std::string type = compile ? "compiling" : "executing";
-       std::cout << "std::exception while " << type << "  SPOOF CUDA operator 
" << name << ":\n" << e.what() << std::endl;
+       std::cerr << "std::exception while " << type << "  SPOOF CUDA operator 
" << name << ":\n" << e.what() << std::endl;
 }
 
 
-// a pod struct to have names for the passed pointers
-template<typename T>
-struct LaunchMetadata {
-       const T& opID;
-       const T& grix;
-       const size_t& num_inputs;
-       const size_t& num_sides;
-       
-       // num entries describing one matrix (6 entries):
-       // {nnz,rows,cols,row_ptr,col_idxs,data}
-       const size_t& entry_size;
-       const T& num_scalars;
-       
-       explicit LaunchMetadata(const size_t* jvals) : opID(jvals[0]), 
grix(jvals[1]), num_inputs(jvals[2]),
-                       num_sides(jvals[3]), entry_size(jvals[4]), 
num_scalars(jvals[5]) {}
-};
-
-
-[[maybe_unused]] JNIEXPORT jlong JNICALL
-Java_org_apache_sysds_hops_codegen_SpoofCompiler_initialize_1cuda_1context(
-       JNIEnv *jenv, [[maybe_unused]] jobject jobj, jint device_id, jstring 
resource_path) {
+[[maybe_unused]] JNIEXPORT jlong JNICALL 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_initialize_1cuda_1context
+       (JNIEnv *jenv, [[maybe_unused]] jobject jobj, jint device_id, jstring 
resource_path)
+{
        const char *cstr_rp = jenv->GetStringUTFChars(resource_path, nullptr);
        size_t ctx = SpoofCUDAContext::initialize_cuda(device_id, cstr_rp);
        jenv->ReleaseStringUTFChars(resource_path, cstr_rp);
+
+       // fetch some jcuda class handles
+       jcuda_pointer_class = jenv->FindClass("jcuda/Pointer");
+       jcuda_native_pointer_class = 
jenv->FindClass("jcuda/NativePointerObject");
+       pointer_buffer_field = jenv->GetFieldID(jcuda_pointer_class, "buffer", 
"Ljava/nio/Buffer;");
+       native_pointer_field = jenv->GetFieldID(jcuda_native_pointer_class, 
"nativePointer", "J");
+
+       // explicit cast to make compiler and linter happy
        return static_cast<jlong>(ctx);
 }
 
 
-[[maybe_unused]] JNIEXPORT void JNICALL
-Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context(
-               [[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] jobject jobj, 
jlong ctx, jint device_id) {
+[[maybe_unused]] JNIEXPORT void JNICALL 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context
+       ([[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] jobject jobj, jlong 
ctx, jint device_id) {
        SpoofCUDAContext::destroy_cuda(reinterpret_cast<SpoofCUDAContext 
*>(ctx), device_id);
 }
 
+[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofOperator_getNativeStagingBuffer
+       (JNIEnv *jenv, [[maybe_unused]] jclass jobj, jobject ptr, jlong _ctx, 
jint size) {
+       std::string operator_name("SpoofOperator_getNativeStagingBuffer");
+       try {
+               // retrieve data handles from JVM
+               auto *ctx = reinterpret_cast<SpoofCUDAContext *>(_ctx);
+               if (size > ctx->current_mem_size)
+                       ctx->resize_staging_buffer(size);
+
+               jobject object = jenv->NewDirectByteBuffer(ctx->staging_buffer, 
size);
+               jenv->SetObjectField(ptr, pointer_buffer_field, object);
+               jenv->SetLongField(ptr,native_pointer_field, 
reinterpret_cast<jlong>(ctx->staging_buffer));
+
+               return 0;
+       }
+       catch (std::exception & e) {
+               printException(operator_name, e);
+       }
+       catch (...) {
+               printException(operator_name, std::runtime_error("unknown 
exception"), true);
+       }
+       return -1;
+}
 
 template<typename TEMPLATE>
-int compile_spoof_operator(JNIEnv *jenv, [[maybe_unused]] jobject jobj, jlong 
_ctx, jstring name, jstring src, TEMPLATE op) {
+int compile_spoof_operator
+       (JNIEnv *jenv, [[maybe_unused]] jobject jobj, jlong _ctx, jstring name, 
jstring src, TEMPLATE op) {
        std::string operator_name;
        try {
                auto *ctx = reinterpret_cast<SpoofCUDAContext *>(_ctx);
                const char *cstr_name = jenv->GetStringUTFChars(name, nullptr);
                const char *cstr_src = jenv->GetStringUTFChars(src, nullptr);
                operator_name = cstr_name;
-
                op->name = operator_name;
+
                int status = ctx->compile(std::move(op), cstr_src);
-               
+
                jenv->ReleaseStringUTFChars(src, cstr_src);
                jenv->ReleaseStringUTFChars(name, cstr_name);
                return status;
        }
-       catch (std::exception& e) {
+       catch (std::exception &e) {
                printException(operator_name, e, true);
        }
        catch (...) {
@@ -96,73 +113,41 @@ int compile_spoof_operator(JNIEnv *jenv, [[maybe_unused]] 
jobject jobj, jlong _c
 
 
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_hops_codegen_cplan_CNodeCell_compile_1nvrtc
-               (JNIEnv *jenv, jobject jobj, jlong ctx, jstring name, jstring 
src, jint type, jint agg_op,
-                               jboolean sparseSafe) {
-       
+       (JNIEnv *jenv, jobject jobj, jlong ctx, jstring name, jstring src, jint 
type, jint agg_op, jboolean sparseSafe) {
        std::unique_ptr<SpoofCellwiseOp> op = 
std::make_unique<SpoofCellwiseOp>(SpoofOperator::AggType(type),
-                       SpoofOperator::AggOp(agg_op), sparseSafe);
-       
+                                                                               
                                                                        
SpoofOperator::AggOp(agg_op), sparseSafe);
+
        return compile_spoof_operator<std::unique_ptr<SpoofCellwiseOp>>(jenv, 
jobj, ctx, name, src, std::move(op));
 }
 
-
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_hops_codegen_cplan_CNodeRow_compile_1nvrtc
-               (JNIEnv *jenv, jobject jobj, jlong ctx, jstring name, jstring 
src, jint type, jint const_dim2,
-                               jint num_vectors, jboolean TB1) {
-       
+       (JNIEnv *jenv, jobject jobj, jlong ctx, jstring name, jstring src, jint 
type, jint const_dim2,
+        jint num_vectors, jboolean TB1) {
        std::unique_ptr<SpoofRowwiseOp> op = 
std::make_unique<SpoofRowwiseOp>(SpoofOperator::RowType(type), TB1,
-                       num_vectors, const_dim2);
+                                                                               
                                                                  num_vectors, 
const_dim2);
        return compile_spoof_operator<std::unique_ptr<SpoofRowwiseOp>>(jenv, 
jobj, ctx, name, src, std::move(op));
 }
 
 
 template<typename T, typename TEMPLATE>
-int launch_spoof_operator(JNIEnv *jenv, [[maybe_unused]] jclass jobj, jlong 
_ctx, jlongArray _meta, jlongArray in,
-               jlongArray _sides, jlongArray out, jlong _scalars) {
-       std::string operator_name("unknown");
+int launch_spoof_operator([[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] 
jclass jobj, jlong _ctx) {
+       std::string operator_name("launch_spoof_operator jni-bridge");
        try {
                // retrieve data handles from JVM
-               auto *metacast = reinterpret_cast<size_t *>(GET_ARRAY(jenv, 
_meta));
                auto *ctx = reinterpret_cast<SpoofCUDAContext *>(_ctx);
-               auto *inputs = reinterpret_cast<size_t *>(GET_ARRAY(jenv, in));
-               auto *sides = reinterpret_cast<size_t *>(GET_ARRAY(jenv, 
_sides));
-               auto *output = reinterpret_cast<size_t *>(GET_ARRAY(jenv, out));
-//             auto *scalars = reinterpret_cast<T *>(GET_ARRAY(jenv, 
_scalars));
-               auto *scalars = reinterpret_cast<T *>(_scalars);
-               LaunchMetadata<size_t> meta(metacast);
-               
+
+#ifndef NDEBUG
+               uint32_t opID = 
*reinterpret_cast<uint32_t*>(&ctx->staging_buffer[sizeof(uint32_t)]);
                // this implicitly checks if op exists
-               operator_name = ctx->getOperatorName(meta.opID);
-               
-               // wrap/cast inputs
-               std::vector<Matrix<T>> mats_in;
-               for(auto i = 0ul; i < meta.num_inputs; i+=meta.entry_size)
-                       mats_in.emplace_back(&inputs[i]);
-               
-               // wrap/cast sides
-               std::vector<Matrix<T>> mats_sides;
-               for(auto i = 0ul; i < meta.num_sides; i+=meta.entry_size)
-                       mats_sides.emplace_back(&sides[i]);
-               
-               // wrap/cast output
-               Matrix<T> mat_out(output);
-               
-               // wrap/cast scalars
-//             std::unique_ptr<Matrix<T>> mat_scalars = scalars == nullptr ? 0 
: std::make_unique<Matrix<T>>(scalars);
-               
+               operator_name = ctx->getOperatorName(opID);
+               std::cout << "executing op=" << operator_name << " id=" << opID 
<< std::endl;
+#endif
                // transfers resource pointers to GPU and calls op->exec()
-               ctx->launch<T, TEMPLATE>(meta.opID, mats_in, mats_sides, 
mat_out, scalars, meta.grix);
-               
-               // release data handles from JVM
-               RELEASE_ARRAY(jenv, _meta, metacast);
-               RELEASE_ARRAY(jenv, in, inputs);
-               RELEASE_ARRAY(jenv, _sides, sides);
-               RELEASE_ARRAY(jenv, out, output);
-//             RELEASE_ARRAY(jenv, _scalars, scalars);
-               
+               ctx->launch<T, TEMPLATE>();
+
                return 0;
        }
-       catch (std::exception& e) {
+       catch (std::exception &e) {
                printException(operator_name, e);
        }
        catch (...) {
@@ -171,26 +156,26 @@ int launch_spoof_operator(JNIEnv *jenv, [[maybe_unused]] 
jclass jobj, jlong _ctx
        return -1;
 }
 
-[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f
-               (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, 
jlongArray in, jlongArray sides, jlongArray out,
-                jlong scalars) {
-       return launch_spoof_operator<float, SpoofCellwise<float>>(jenv, jobj, 
ctx, meta, in, sides, out, scalars);
-}
 
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1d
-               (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, 
jlongArray in, jlongArray sides, jlongArray out,
-                jlong scalars) {
-       return launch_spoof_operator<double, SpoofCellwise<double>>(jenv, jobj, 
ctx, meta, in, sides, out, scalars);
+       (JNIEnv *jenv, jclass jobj, jlong ctx) {
+       return launch_spoof_operator<double, SpoofCellwise<double>>(jenv, jobj, 
ctx);
 }
 
-[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f
-               (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, 
jlongArray in, jlongArray sides, jlongArray out,
-                jlong scalars) {
-       return launch_spoof_operator<float, SpoofRowwise<float>>(jenv, jobj, 
ctx, meta, in, sides, out, scalars);
+
+[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f
+       (JNIEnv *jenv, jclass jobj, jlong ctx) {
+       return launch_spoof_operator<double, SpoofCellwise<double>>(jenv, jobj, 
ctx);
 }
 
+
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1d
-               (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, 
jlongArray in, jlongArray sides, jlongArray out,
-                jlong scalars) {
-       return launch_spoof_operator<double, SpoofRowwise<double>>(jenv, jobj, 
ctx, meta, in, sides, out, scalars);
+       (JNIEnv *jenv, jclass jobj, jlong ctx) {
+       return launch_spoof_operator<double, SpoofRowwise<double>>(jenv, jobj, 
ctx);
 }
+
+
+[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f
+       (JNIEnv *jenv, jclass jobj, jlong ctx) {
+       return launch_spoof_operator<double, SpoofRowwise<double>>(jenv, jobj, 
ctx);
+}
\ No newline at end of file
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.h 
b/src/main/cuda/spoof-launcher/jni_bridge.h
index 1e9ef20..0f25936 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.h
+++ b/src/main/cuda/spoof-launcher/jni_bridge.h
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-/* DO NOT EDIT THIS FILE - it is machine generated */
+/* DO EDIT THIS FILE - it is not machine generated */
 
 #pragma once
 #ifndef JNI_BRIDGE_H
@@ -32,20 +32,26 @@ extern "C" {
 /*
  * Class:     org_apache_sysds_hops_codegen_SpoofCompiler
  * Method:    initialize_cuda_context
- * Signature: (I)J
+ * Signature: (ILjava/lang/String;)J
  */
-[[maybe_unused]] JNIEXPORT jlong JNICALL
-Java_org_apache_sysds_hops_codegen_SpoofCompiler_initialize_1cuda_1context(
-    JNIEnv *, [[maybe_unused]] jobject, jint, jstring);
+[[maybe_unused]] JNIEXPORT jlong JNICALL 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_initialize_1cuda_1context
+       (JNIEnv *, [[maybe_unused]] jobject, jint, jstring);
 
 /*
  * Class:     org_apache_sysds_hops_codegen_SpoofCompiler
  * Method:    destroy_cuda_context
  * Signature: (JI)V
  */
-[[maybe_unused]] JNIEXPORT void JNICALL
-Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context(
-               [[maybe_unused]] JNIEnv *, [[maybe_unused]] jobject, jlong, 
jint);
+[[maybe_unused]] JNIEXPORT void JNICALL 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context
+       ([[maybe_unused]] JNIEnv *, [[maybe_unused]] jobject, jlong, jint);
+
+/*
+ * Class:     org_apache_sysds_runtime_codegen_SpoofOperator
+ * Method:    getNativeStagingBuffer
+ * Signature: (Ljcuda/Pointer;JI)I
+ */
+[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofOperator_getNativeStagingBuffer
+       (JNIEnv *, jclass, jobject, jlong, jint);
 
 /*
  * Class:     org_apache_sysds_hops_codegen_cplan_CNodeCell
@@ -64,37 +70,36 @@ 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context(
                (JNIEnv *, [[maybe_unused]] jobject, jlong, jstring, jstring, 
jint, jint, jint, jboolean);
 
 /*
- * Class:     org_apache_sysds_runtime_codegen_SpoofCUDACellwiseOperator
- * Method:    execute_f
- * Signature: (J[J[J[J[JJ)I
- */
-[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f
-               (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray, 
jlongArray, jlong);
-
-/*
- * Class:     org_apache_sysds_runtime_codegen_SpoofCUDACellwiseOperator
+ * Class:     org_apache_sysds_runtime_codegen_SpoofCUDACellwise
  * Method:    execute_d
- * Signature: (J[J[J[J[JJ)I
+ * Signature: (J)I
  */
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1d
-               (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray, 
jlongArray, jlong);
-
+       (JNIEnv *, jclass, jlong);
 
 /*
- * Class:     org_apache_sysds_runtime_codegen_SpoofCUDARowwise
+ * Class:     org_apache_sysds_runtime_codegen_SpoofCUDACellwise
  * Method:    execute_f
- * Signature: (J[J[J[J[JJ)I
+ * Signature: (J)I
  */
-[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f
-               (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray, 
jlongArray, jlong);
+[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f
+       (JNIEnv *, jclass, jlong);
 
 /*
  * Class:     org_apache_sysds_runtime_codegen_SpoofCUDARowwise
  * Method:    execute_d
- * Signature: (J[J[J[J[JJ)I
+ * Signature: (J)I
  */
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1d
-               (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray, 
jlongArray, jlong);
+       (JNIEnv *, jclass, jlong);
+
+/*
+ * Class:     org_apache_sysds_runtime_codegen_SpoofCUDARowwise
+ * Method:    execute_f
+ * Signature: (J)I
+ */
+[[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f
+       (JNIEnv *, jclass, jlong);
 
 #ifdef __cplusplus
 }
diff --git a/src/main/cuda/spoof/cellwise.cu b/src/main/cuda/spoof/cellwise.cu
index 951f709..6751b1d 100644
--- a/src/main/cuda/spoof/cellwise.cu
+++ b/src/main/cuda/spoof/cellwise.cu
@@ -42,7 +42,8 @@ struct SpoofCellwiseOp {
                uint32_t _grix;
 
        SpoofCellwiseOp(Matrix<T>* _A, Matrix<T>* _B, Matrix<T>* _C, T* 
scalars, uint32_t grix) :
-                       n(_A->cols), scalars(scalars), _grix(grix) {
+               n(_A->cols), scalars(scalars), _grix(grix)
+       {
                A.init(_A);
                c.init(_C);
                alen = A.row_len(grix);
diff --git a/src/main/cuda/spoof/rowwise.cu b/src/main/cuda/spoof/rowwise.cu
index 7d80d54..b31ce0c 100644
--- a/src/main/cuda/spoof/rowwise.cu
+++ b/src/main/cuda/spoof/rowwise.cu
@@ -55,6 +55,8 @@ struct SpoofRowwiseOp //%HAS_TEMP_VECT%
 
        __device__  __forceinline__ void exec_dense(uint32_t ai, uint32_t ci, 
uint32_t rix) {
 //%BODY_dense%
+               if (debug_row() && debug_thread())
+                       printf("c[0]=%4.3f\n", c.vals(0)[0]);
        }
 
        __device__  __forceinline__ void exec_sparse(uint32_t ai, uint32_t ci, 
uint32_t rix) {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
index 47a7178..f803000 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
@@ -156,10 +156,12 @@ public class CNodeCell extends CNodeTpl
                if(tmpDense.contains("grix"))
                        tmp = tmp.replace("//%NEED_GRIX%", "\t\tuint32_t 
grix=_grix + rix;");
                else
-                       tmp = tmp.replace("//%NEED_GRIX%", "");
-               tmp = tmp.replace("//%NEED_RIX%", "");
-               tmp = tmp.replace("//%NEED_CIX%", "");
-               
+                       tmp = tmp.replace("//%NEED_GRIX%\n", "");
+
+               // remove empty lines
+//             if(!api.isJava())
+//                     tmpDense = tmpDense.replaceAll("(?m)^[ \t]*\r?\n", "");
+
                tmp = tmp.replace("%BODY_dense%", tmpDense);
                
                //Return last TMP. Square it for CUDA+SUM_SQ
@@ -171,10 +173,7 @@ public class CNodeCell extends CNodeTpl
                tmp = tmp.replace("%AGG_OP_NAME%", (_aggOp != null) ? "AggOp." 
+ _aggOp.name() : "null");
                tmp = tmp.replace("%SPARSE_SAFE%", 
String.valueOf(isSparseSafe()));
                tmp = tmp.replace("%SEQ%", String.valueOf(containsSeq()));
-               
-               // maybe empty lines
-               //tmp = tmp.replaceAll("(?m)^[ \t]*\r?\n", "");
-               
+
                if(api == GeneratorAPI.CUDA) {
                        // ToDo: initial_value is misused to pass VT (values 
per thread) to no_agg operator
                        String agg_op = "IdentityOp";
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
index 0e21017..f2405d5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
@@ -191,7 +191,7 @@ public class Unary extends CodeTemplate {
                                case MULT2:
                                        return "        T %TMP% = %IN1% + 
%IN1%;\n";
                                case ABS:
-                                       return "        T %TMP% = 
fabs(%IN1%);\n";
+                                       return "\t\tT %TMP% = fabs(%IN1%);\n";
                                case SIN:
                                        return "        T %TMP% = 
sin(%IN1%);\n";
                                case COS:
@@ -217,7 +217,7 @@ public class Unary extends CodeTemplate {
                                case LOG:
                                        return "                T %TMP% = 
log(%IN1%);\n";
                                case ROUND:
-                                       return "        T %TMP% = 
round(%IN1%);\n";
+                                       return "\t\tT %TMP% = round(%IN1%);\n";
                                case CEIL:
                                        return "        T %TMP% = 
ceil(%IN1%);\n";
                                case FLOOR:
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
index 0510941..5a32114 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
@@ -23,11 +23,11 @@ import jcuda.Pointer;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
 
@@ -38,101 +38,75 @@ public class SpoofCUDACellwise extends SpoofCellwise 
implements SpoofCUDAOperato
        private static final Log LOG = 
LogFactory.getLog(SpoofCUDACellwise.class.getName());
        private final int ID;
        private final PrecisionProxy call;
-       private Pointer ptr;
        private final SpoofCellwise fallback_java_op;
+       private final long ctx;
        
        public SpoofCUDACellwise(CellType type, boolean sparseSafe, boolean 
containsSeq, AggOp aggOp, int id,
-                       PrecisionProxy ep, SpoofCellwise fallback) {
+               PrecisionProxy ep, SpoofCellwise fallback)
+       {
                super(type, sparseSafe, containsSeq, aggOp);
                ID = id;
                call = ep;
-               ptr = null;
                fallback_java_op = fallback;
+               ctx = 
SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA);
        }
        
        @Override
-       public ScalarObject execute(ExecutionContext ec, 
ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects) {
+       public ScalarObject execute(ExecutionContext ec, 
ArrayList<MatrixObject> inputs,
+               ArrayList<ScalarObject> scalarObjects)
+       {
                double[] result = new double[1];
-               // ToDo: this is a temporary "solution" before perf opt
-               int NT=256;
-               long N = inputs.get(0).getNumRows() * 
inputs.get(0).getNumColumns();
-               long num_blocks = ((N + NT * 2 - 1) / (NT * 2));
-               Pointer ptr = ec.getGPUContext(0).allocate(getName(), 
LibMatrixCUDA.sizeOfDataType * num_blocks);
-               long[] out = {1,1,1, 0, 0, GPUObject.getPointerAddress(ptr)};
-               int offset = 1;
-               if(call.exec(ec, this, ID, prepareInputPointers(ec, inputs, 
offset), 
-                       prepareSideInputPointers(ec, inputs, offset, false), 
out, scalarObjects, 0 ) != 0) {
-                       LOG.error("SpoofCUDA " + getSpoofType() + " operator 
failed to execute. Trying Java fallback.\n");
-                       // ToDo: java fallback
+               Pointer[] ptr = new Pointer[1];
+               packDataForTransfer(ec, inputs, scalarObjects, null, 1, ID, 
0,false, ptr);
+               if(NotEmpty(inputs.get(0).getGPUObject(ec.getGPUContext(0)))) {
+                       if(call.exec(this) != 0)
+                               LOG.error("SpoofCUDA " + getSpoofType() + " 
operator " + ID + " failed to execute!\n");
                }
-               
LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr, 
result, getName(), false);
-               
+               
LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr[0], 
result, getName(), false);
+               ec.getGPUContext(0).cudaFreeHelper(getSpoofType(), ptr[0], 
DMLScript.EAGER_CUDA_FREE);
                return new DoubleObject(result[0]);
        }
        
        @Override public String getName() {
                return getSpoofType();
        }
-       
-       @Override public void setScalarPtr(Pointer _ptr) {
-               ptr = _ptr;
-       }
-       
-       @Override public Pointer getScalarPtr() {
-               return ptr;
-       }
-       
-       @Override public void releaseScalarGPUMemory(ExecutionContext ec) {
-               if(ptr != null) {
-                       ec.getGPUContext(0).cudaFreeHelper(getSpoofType(), ptr, 
DMLScript.EAGER_CUDA_FREE);
-                       ptr = null;
-               }
-       }
-       
+
        @Override
-       public MatrixObject execute(ExecutionContext ec, 
ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects,
-                       String outputName) {
-               
+       public MatrixObject execute(ExecutionContext ec, 
ArrayList<MatrixObject> inputs,
+               ArrayList<ScalarObject> scalarObjects, String outputName)
+       {
                long out_rows = ec.getMatrixObject(outputName).getNumRows();
                long out_cols = ec.getMatrixObject(outputName).getNumColumns();
-               MatrixObject a = inputs.get(0);
-               GPUContext gctx = ec.getGPUContext(0);
-               int m = (int) a.getNumRows();
-               int n = (int) a.getNumColumns();
-               double[] scalars = prepInputScalars(scalarObjects);
+
                if(_type == CellType.COL_AGG)
                        out_rows = 1;
                else if(_type == SpoofCellwise.CellType.ROW_AGG)
                        out_cols = 1;
-               
+
+               double[] scalars = prepInputScalars(scalarObjects);
                boolean sparseSafe = isSparseSafe() || ((inputs.size() < 2) && 
-                               genexec( 0, new SideInput[0], scalars, m, n, 0, 
0 ) == 0);
-               
-//             ec.setMetaData(outputName, out_rows, out_cols);
-               GPUObject g = a.getGPUObject(gctx);
-               boolean sparseOut = _type == CellType.NO_AGG && sparseSafe && 
g.isSparse();
-               
-               long nnz = g.getNnz("spoofCUDA" + getSpoofType(), false);
-               if(sparseOut)
-                       LOG.warn("sparse out");
+                               genexec( 0, new SideInput[0], scalars, (int) 
inputs.get(0).getNumRows(),
+                                       (int) inputs.get(0).getNumColumns(), 0, 
0 ) == 0);
+
+               GPUObject in_obj = 
inputs.get(0).getGPUObject(ec.getGPUContext(0));
+               boolean sparseOut = _type == CellType.NO_AGG && sparseSafe && 
in_obj.isSparse();
+               long nnz = in_obj.getNnz("spoofCUDA" + getSpoofType(), false);
                MatrixObject out_obj = sparseOut ?
                                
(ec.getSparseMatrixOutputForGPUInstruction(outputName, out_rows, out_cols, 
(isSparseSafe() && nnz > 0) ?
                                                nnz : out_rows * 
out_cols).getKey()) :
                                
(ec.getDenseMatrixOutputForGPUInstruction(outputName, out_rows, 
out_cols).getKey());
-               
-               int offset = 1;
-               if(!inputIsEmpty(a.getGPUObject(gctx)) || !sparseSafe) {
-                       if(call.exec(ec, this, ID, prepareInputPointers(ec, 
inputs, offset), prepareSideInputPointers(ec, inputs, offset, false),
-                               prepareOutputPointers(ec, out_obj, sparseOut), 
scalarObjects, 0) != 0) {
-                               LOG.error("SpoofCUDA " + getSpoofType() + " 
operator failed to execute. Trying Java fallback.(ToDo)\n");
-                               // ToDo: java fallback
-                       }
+
+               packDataForTransfer(ec, inputs, scalarObjects, out_obj, 1, ID, 
0,false, null);
+               if(NotEmpty(in_obj) || !sparseSafe) {
+                       if(call.exec(this) != 0)
+                               LOG.error("SpoofCUDA " + getSpoofType() + " 
operator " + ID + " failed to execute!\n");
                }
                return out_obj;
        }
        
-       private static boolean inputIsEmpty(GPUObject g) {
-               return g.getDensePointer() == null && 
g.getSparseMatrixCudaPointer() == null;
+       private static boolean NotEmpty(GPUObject g) {
+               // ToDo: check if that check is sufficient
+               return g.getDensePointer() != null || 
g.getSparseMatrixCudaPointer() != null;
        }
        
        // used to determine sparse safety
@@ -140,15 +114,11 @@ public class SpoofCUDACellwise extends SpoofCellwise 
implements SpoofCUDAOperato
        protected double genexec(double a, SideInput[] b, double[] scalars, int 
m, int n, long gix, int rix, int cix) {
                return fallback_java_op.genexec(a, b, scalars, m, n, 0, 0, 0);
        }
-       
-       public int execute_sp(long ctx, long[] meta, long[] in, long[] sides, 
long[] out, long scalars) {
-               return execute_f(ctx, meta, in, sides, out, scalars);   
-       }
-       
-       public int execute_dp(long ctx, long[] meta, long[] in, long[] sides, 
long[] out, long scalars) {
-               return execute_d(ctx, meta, in, sides, out, scalars);
-       }
-       
-       public static native int execute_f(long ctx, long[] meta, long[] in, 
long[] sides, long[] out, long scalars);
-       public static native int execute_d(long ctx, long[] meta, long[] in, 
long[] sides, long[] out, long scalars);
+
+       public int execute_dp(long ctx) { return execute_d(ctx); }
+       public int execute_sp(long ctx) { return execute_d(ctx); }
+       public long getContext() { return ctx; }
+
+       public static native int execute_d(long ctx);
+       public static native int execute_s(long ctx);
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDAOperator.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDAOperator.java
index 118f97a..02fbf96 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDAOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDAOperator.java
@@ -19,7 +19,10 @@
 
 package org.apache.sysds.runtime.codegen;
 
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import jcuda.Pointer;
+
 import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -27,146 +30,112 @@ import 
org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
 
-import java.util.ArrayList;
-
+import static jcuda.runtime.cudaError.cudaSuccess;
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixCUDA.sizeOfDataType;
 
 public interface SpoofCUDAOperator  {
-       int JNI_MAT_ENTRY_SIZE = 6;
+       // these two constants have equivalences in native code:
+       int JNI_MAT_ENTRY_SIZE = 40;
+       int TRANSFERRED_DATA_HEADER_SIZE = 32;
+
        abstract class PrecisionProxy {
                protected final long ctx;
                
                public PrecisionProxy() { ctx = 
SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA);     }
                
-               public abstract int exec(ExecutionContext ec, SpoofCUDAOperator 
op, int opID, long[] in, long[] sides, long[] out,
-                               ArrayList<ScalarObject> scalarObjects, long 
grix);
-               
-               protected Pointer transferScalars(ExecutionContext ec, 
SpoofCUDAOperator op, int sizeOfDataType,
-                               ArrayList<ScalarObject> scalarObjects) {
-                       double[] s = 
SpoofOperator.prepInputScalars(scalarObjects);
-                       Pointer ptr = 
ec.getGPUContext(0).allocate(op.getName(), (long) scalarObjects.size() * 
sizeOfDataType);
-                       
LibMatrixCUDA.cudaSupportFunctions.hostToDevice(ec.getGPUContext(0), s, ptr, 
op.getName());
-                       return ptr;
-               }
+               public abstract int exec(SpoofCUDAOperator op);
        }
        
        String getName();
-       
-       void setScalarPtr(Pointer ptr);
-       
-       Pointer getScalarPtr();
-       
-       void releaseScalarGPUMemory(ExecutionContext ec);
-       
-       default long [] prepareInputPointers(ExecutionContext ec, 
ArrayList<MatrixObject> inputs, int offset) {
-               long [] in = new long[offset * JNI_MAT_ENTRY_SIZE];
-               for(int i = 0; i < offset; i++) {
-                       int j = i  * JNI_MAT_ENTRY_SIZE;
-                       
-                       
if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) {
-                               in[j] = 
ec.getGPUSparsePointerAddress(inputs.get(i)).nnz;
-                               in[j + 1] = inputs.get(i).getNumRows();
-                               in[j + 2] = inputs.get(i).getNumColumns();
-                               in[j + 3] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr);
-                               in[j + 4] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd);
-                               in[j + 5] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val);
-                       }
-                       else {
-                               in[j] = inputs.get(i).getNnz();
-                               in[j + 1] = inputs.get(i).getNumRows();
-                               in[j + 2] = inputs.get(i).getNumColumns();
-                               in[j + 5] = 
ec.getGPUDensePointerAddress(inputs.get(i));
-                       }
-               }
-               return in;
+
+       default void writeMatrixDescriptorToBuffer(ByteBuffer dst, int rows, 
int cols, long row_ptr,
+               long col_idx_ptr, long data_ptr, long nnz)
+       {
+               dst.putLong(nnz);
+               dst.putInt(rows);
+               dst.putInt(cols);
+               dst.putLong(row_ptr);
+               dst.putLong(col_idx_ptr);
+               dst.putLong(data_ptr);
        }
-       
-       default long [] prepareSideInputPointers(ExecutionContext ec, 
ArrayList<MatrixObject> inputs, int offset, boolean tB1) {
-               long[] sides = new long[(inputs.size() - offset) * 
JNI_MAT_ENTRY_SIZE];
-               for(int i = offset; i < inputs.size(); i++) {
-                       int j = (i - offset)  * JNI_MAT_ENTRY_SIZE;
-                       
if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) {
-                               sides[j] = 
ec.getGPUSparsePointerAddress(inputs.get(i)).nnz;
-                               sides[j + 1] = inputs.get(i).getNumRows();
-                               sides[j + 2] = inputs.get(i).getNumColumns();
-                               sides[j + 3] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr);
-                               sides[j + 4] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd);
-                               sides[j + 5] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val);
-                       }
-                       else {
-                               if(tB1 && j == 0) {
-                                       long rows = inputs.get(i).getNumRows();
-                                       long cols = 
inputs.get(i).getNumColumns();
-                                       Pointer b1 = 
inputs.get(i).getGPUObject(ec.getGPUContext(0)).getDensePointer();
-                                       Pointer ptr = 
ec.getGPUContext(0).allocate(getName(), rows * cols * sizeOfDataType);
-                                       
-//                                     double[] tmp1 = new double[(int) (rows 
* cols)];
-//                                     
LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), b1, tmp1, 
getName(), false);
-//
-//                                     System.out.println("Mat before 
transpose: rows=" + rows + " cols=" + cols + "\n");
-//                                     for(int m = 0; m < rows; m++) {
-//                                             StringBuilder sb = new 
StringBuilder();
-//                                             for(int n = 0; n < cols; n++)
-//                                                     sb.append(" " + 
tmp1[(int) (cols * m + n)]);
-//                                             
System.out.println(sb.toString());
-//                                     }
-                                       
-                                       LibMatrixCUDA.denseTranspose(ec, 
ec.getGPUContext(0), getName(),
-                                               b1, ptr, rows, cols);
-                                       
-//                                     double[] tmp2 = new double[(int) (rows 
* cols)];
-//                                     
LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr, tmp2, 
getName(), false);
-//
-//                                     System.out.println("Mat after 
transpose: rows=" + cols + " cols=" + rows + "\n");
-//                                     for(int m = 0; m < cols; m++) {
-//                                             StringBuilder sb = new 
StringBuilder();
-//                                             for(int n = 0; n < rows; n++)
-//                                                     sb.append(" " + 
tmp2[(int) (rows * m + n)]);
-//                                             
System.out.println(sb.toString());
-//                                     }
 
-                                       sides[j] = inputs.get(i).getNnz();
-                                       sides[j + 1] = cols;
-                                       sides[j + 2] = rows;
-                                       sides[j + 5] = 
GPUObject.getPointerAddress(ptr);
-                                       
-                               } else {
-                                       sides[j] = inputs.get(i).getNnz();
-                                       sides[j + 1] = 
inputs.get(i).getNumRows();
-                                       sides[j + 2] = 
inputs.get(i).getNumColumns();
-                                       sides[j + 5] = 
ec.getGPUDensePointerAddress(inputs.get(i));
-                               }
+       default void prepareMatrixPointers(ByteBuffer buf, ExecutionContext ec, 
MatrixObject mo, boolean tB1) {
+               if(mo.getGPUObject(ec.getGPUContext(0)).isSparse()) {
+                       writeMatrixDescriptorToBuffer(buf, 
(int)mo.getNumRows(), (int)mo.getNumColumns(),
+                               
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(mo).rowPtr),
+                                       
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(mo).colInd),
+                                               
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(mo).val),
+                                                       
ec.getGPUSparsePointerAddress(mo).nnz);
+               }
+               else {
+                       if(tB1) {
+                               int rows = (int)mo.getNumRows();
+                               int cols = (int)mo.getNumColumns();
+                               Pointer b1 = 
mo.getGPUObject(ec.getGPUContext(0)).getDensePointer();
+                               Pointer ptr = 
ec.getGPUContext(0).allocate(getName(), (long) rows * cols * sizeOfDataType);
+                               LibMatrixCUDA.denseTranspose(ec, 
ec.getGPUContext(0), getName(), b1, ptr, rows, cols);
+                               writeMatrixDescriptorToBuffer(buf, rows, cols, 
0, 0, GPUObject.getPointerAddress(ptr), mo.getNnz());
+                       } else {
+                               writeMatrixDescriptorToBuffer(buf, 
(int)mo.getNumRows(), (int)mo.getNumColumns(), 0, 0,
+                                       ec.getGPUDensePointerAddress(mo), 
mo.getNnz());
                        }
                }
-               return sides;
        }
-       
-       default long[] prepareOutputPointers(ExecutionContext ec, MatrixObject 
output, boolean sparseOut) {
-               long[] out = {0,0,0,0,0,0};
 
-               if(sparseOut) {
-                       out[0] = ec.getGPUSparsePointerAddress(output).nnz;
-                       out[1] = output.getNumRows();
-                       out[2] = output.getNumColumns();
-                       out[3] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(output).rowPtr);
-                       out[4] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(output).colInd);
-                       out[5] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(output).val);
+       default void packDataForTransfer(ExecutionContext ec, 
ArrayList<MatrixObject> inputs,
+               ArrayList<ScalarObject> scalarObjects, MatrixObject out_obj, 
int num_inputs, int ID, long grix, boolean tB1,
+                       Pointer[] ptr)
+       {
+               int op_data_size = (inputs.size() + 1) * JNI_MAT_ENTRY_SIZE + 
scalarObjects.size() * Double.BYTES + TRANSFERRED_DATA_HEADER_SIZE;
+               Pointer staging = new Pointer();
+               if(SpoofOperator.getNativeStagingBuffer(staging, 
this.getContext(), op_data_size) != cudaSuccess)
+                       throw new RuntimeException("Failed to get native 
staging buffer from spoof operator");
+               ByteBuffer buf = staging.getByteBuffer();
+               buf.putInt(op_data_size);
+               buf.putInt(ID);
+               buf.putInt((int)grix);
+               buf.putInt(num_inputs);
+               buf.putInt(inputs.size() - num_inputs);
+               buf.putInt(out_obj == null ? 0 : 1);
+               buf.putInt(scalarObjects.size());
+               buf.putInt(-1); // padding
+
+               // copy input & side input pointers
+               for(int i=0; i < inputs.size(); i++) {
+                       if(i == num_inputs)
+                               prepareMatrixPointers(buf, ec, inputs.get(i), 
tB1);
+                       else
+                               prepareMatrixPointers(buf, ec, inputs.get(i), 
false);
                }
-               else {
-                       out[0] = output.getNnz();
-                       out[1] = output.getNumRows();
-                       out[2] = output.getNumColumns();
-                       out[5] = ec.getGPUDensePointerAddress(output);
+
+               // copy output pointers or allocate buffer for reduction
+               if(out_obj == null) {
+                       long num_blocks = 1;
+                       if(this instanceof SpoofCUDACellwise) {
+                               int NT = 256;
+                               long N = inputs.get(0).getNumRows() * 
inputs.get(0).getNumColumns();
+                               num_blocks = ((N + NT * 2 - 1) / (NT * 2));
+                       }
+                       ptr[0] = ec.getGPUContext(0).allocate(getName(), 
LibMatrixCUDA.sizeOfDataType * num_blocks);
+                       writeMatrixDescriptorToBuffer(buf, 1, 1, 0, 0, 
GPUObject.getPointerAddress(ptr[0]), 1);
+               }
+               else {
+                       prepareMatrixPointers(buf, ec, out_obj, false);
+               }
+
+               // copy scalar values (no pointers)
+               for(ScalarObject scalarObject : scalarObjects) {
+                       buf.putDouble(scalarObject.getDoubleValue());
                }
-               return out;
        }
-       
-       MatrixObject execute(ExecutionContext ec, ArrayList<MatrixObject> 
inputs, 
+
+       MatrixObject execute(ExecutionContext ec, ArrayList<MatrixObject> 
inputs,
                        ArrayList<ScalarObject> scalarObjects, String 
outputName);
        
        ScalarObject execute(ExecutionContext ec, ArrayList<MatrixObject> 
inputs,
                ArrayList<ScalarObject> scalarObjects);
-       
-       int execute_sp(long ctx, long[] meta, long[] in, long[] sides, long[] 
out, long scalars);
-       int execute_dp(long ctx, long[] meta, long[] in, long[] sides, long[] 
out, long scalars);
+
+       int execute_dp(long ctx);
+       int execute_sp(long ctx);
+       long getContext();
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
index b5686c7..2632fae 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
@@ -23,11 +23,11 @@ import jcuda.Pointer;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
 
 import java.util.ArrayList;
@@ -37,55 +37,40 @@ public class SpoofCUDARowwise extends SpoofRowwise 
implements SpoofCUDAOperator
        private static final Log LOG = 
LogFactory.getLog(SpoofCUDARowwise.class.getName());
        private final int ID;
        private final PrecisionProxy call;
-       private Pointer ptr;
+       private final long ctx;
        
        public SpoofCUDARowwise(RowType type,  long constDim2, boolean tB1, int 
reqVectMem, int id,
-               PrecisionProxy ep) {
+               PrecisionProxy ep)
+       {
                super(type, constDim2, tB1, reqVectMem);
                ID = id;
                call = ep;
-               ptr = null;
+               ctx = 
SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA);
        }
        
        @Override public String getName() {
                return getSpoofType();
        }
-       
-       @Override public void setScalarPtr(Pointer _ptr) {
-               ptr = _ptr;
-       }
-       
-       @Override public Pointer getScalarPtr() {
-               return ptr;
-       }
-       
-       @Override public void releaseScalarGPUMemory(ExecutionContext ec) {
-               if(ptr != null) {
-                       ec.getGPUContext(0).cudaFreeHelper(getSpoofType(), ptr, 
DMLScript.EAGER_CUDA_FREE);
-                       ptr = null;
-               }
-       }
-       
+
        @Override
        public ScalarObject execute(ExecutionContext ec, 
ArrayList<MatrixObject> inputs,
-               ArrayList<ScalarObject> scalarObjects) {
+               ArrayList<ScalarObject> scalarObjects)
+       {
                double[] result = new double[1];
-               Pointer ptr = ec.getGPUContext(0).allocate(getName(), 
LibMatrixCUDA.sizeOfDataType);
-               long[] out = {1,1,1, 0, 0, GPUObject.getPointerAddress(ptr)};
-               int offset = 1;
-               if(call.exec(ec, this, ID, prepareInputPointers(ec, inputs, 
offset), prepareSideInputPointers(ec, inputs, offset, _tB1),
-                               out, scalarObjects, 0) != 0) {
-                       LOG.error("SpoofCUDA " + getSpoofType() + " operator 
failed to execute. Trying Java fallback.\n");
-                       // ToDo: java fallback
-               }
-               
LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr, 
result, getName(), false);
+               Pointer[] ptr = new Pointer[1];
+               packDataForTransfer(ec, inputs, scalarObjects, null, 1, ID, 
0,_tB1, ptr);
+               if(call.exec(this) != 0)
+                       LOG.error("SpoofCUDA " + getSpoofType() + " operator " 
+ ID + " failed to execute!\n");
+
+               
LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr[0], 
result, getName(), false);
+               ec.getGPUContext(0).cudaFreeHelper(getSpoofType(), ptr[0], 
DMLScript.EAGER_CUDA_FREE);
                return new DoubleObject(result[0]);
        }
        
        @Override
        public MatrixObject execute(ExecutionContext ec, 
ArrayList<MatrixObject> inputs,
-               ArrayList<ScalarObject> scalarObjects, String outputName) {
-               
+               ArrayList<ScalarObject> scalarObjects, String outputName)
+       {
                int m = (int) inputs.get(0).getNumRows();
                int n = (int) inputs.get(0).getNumColumns();
                final int n2 = _type.isConstDim2(_constDim2) ? (int)_constDim2 
: _type.isRowTypeB1() ||
@@ -93,13 +78,12 @@ public class SpoofCUDARowwise extends SpoofRowwise 
implements SpoofCUDAOperator
                OutputDimensions out_dims = new OutputDimensions(m, n, n2);
                ec.setMetaData(outputName, out_dims.rows, out_dims.cols);
                MatrixObject out_obj = 
ec.getDenseMatrixOutputForGPUInstruction(outputName, out_dims.rows, 
out_dims.cols).getKey();
-               
-               int offset = 1;
-               if(call.exec(ec,this, ID, prepareInputPointers(ec, inputs, 
offset), prepareSideInputPointers(ec, inputs, 
-                               offset, _tB1), prepareOutputPointers(ec, 
out_obj, false), scalarObjects, 0) != 0) {
-                       LOG.error("SpoofCUDA " + getSpoofType() + " operator 
failed to execute. Trying Java fallback.\n");
-                       // ToDo: java fallback
-               }
+
+               packDataForTransfer(ec, inputs, scalarObjects, out_obj, 1, ID, 
0,_tB1, null);
+
+               if(call.exec(this) != 0)
+                       LOG.error("SpoofCUDA " + getSpoofType() + " operator " 
+ ID + " failed to execute!\n");
+
                return out_obj;
        }
        
@@ -110,15 +94,11 @@ public class SpoofCUDARowwise extends SpoofRowwise 
implements SpoofCUDAOperator
        // unused
        @Override protected void genexec(double[] avals, int[] aix, int ai, 
SideInput[] b, double[] scalars, double[] c,
                int ci, int alen, int n, long grix, int rix) { }
-       
-       public int execute_sp(long ctx, long[] meta, long[] in, long[] sides, 
long[] out, long scalars) {
-               return execute_f(ctx, meta, in, sides, out, scalars);
-       }
-       
-       public int execute_dp(long ctx, long[] meta, long[] in, long[] sides, 
long[] out, long scalars) {
-               return execute_d(ctx, meta, in, sides, out, scalars);
-       }
-       
-       public static native int execute_f(long ctx, long[] meta, long[] in, 
long[] sides, long[] out, long scalars);
-       public static native int execute_d(long ctx, long[] meta, long[] in, 
long[] sides, long[] out, long scalars);
+
+       public int execute_dp(long ctx) { return execute_d(ctx); }
+       public int execute_sp(long ctx) { return execute_d(ctx); }
+       public long getContext() { return ctx; }
+
+       public static native int execute_d(long ctx);
+       public static native int execute_s(long ctx);
 }
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java
index fe8d932..f270091 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java
@@ -23,6 +23,7 @@ import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 
+import jcuda.Pointer;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -321,4 +322,6 @@ public abstract class SpoofOperator implements Serializable
                        currColPos = 0;
                }
        }
+
+       public static native int getNativeStagingBuffer(Pointer ptr, long 
context, int size);
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
index 9fd1bcf..b54d6fe 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
@@ -35,7 +35,6 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.utils.GPUStatistics;
@@ -52,21 +51,14 @@ public class SpoofCUDAInstruction extends GPUInstruction {
        public final CPOperand _out;
        
        public static class SinglePrecision extends 
SpoofCUDAOperator.PrecisionProxy {
-               public int exec(ExecutionContext ec, SpoofCUDAOperator op, int 
opID, long[] in, long[] sides, long[] out,
-                               ArrayList<ScalarObject> scalarObjects, long 
grix) {
-                       op.setScalarPtr(transferScalars(ec, op, Sizeof.FLOAT, 
scalarObjects));
-                       long[] _metadata = { opID, grix, in.length, 
sides.length, out.length, scalarObjects.size() };
-                       return op.execute_sp(ctx, _metadata, in, sides, out, 
GPUObject.getPointerAddress(op.getScalarPtr()));
+               public int exec(SpoofCUDAOperator op) {
+                       return op.execute_sp(ctx);
                }
        }
        
        public static class DoublePrecision extends 
SpoofCUDAOperator.PrecisionProxy {
-               public int exec(ExecutionContext ec, SpoofCUDAOperator op, int 
opID, long[] in, long[] sides, long[] out,
-                               ArrayList<ScalarObject> scalarObjects, long 
grix) {
-                       if(!scalarObjects.isEmpty())
-                               op.setScalarPtr(transferScalars(ec, op, 
Sizeof.DOUBLE, scalarObjects));
-                       long[] _metadata = { opID, grix, in.length, 
sides.length, out.length, scalarObjects.size() };
-                       return op.execute_dp(ctx, _metadata, in, sides, out, 
GPUObject.getPointerAddress(op.getScalarPtr()));
+               public int exec(SpoofCUDAOperator op) {
+                       return op.execute_dp(ctx);
                }
        }
        
@@ -101,7 +93,6 @@ public class SpoofCUDAInstruction extends GPUInstruction {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
 
                ArrayList<CPOperand> inlist = new ArrayList<>();
-//             Integer op_id =  
CodegenUtils.getCUDAopID(parts[2].split("\\.")[1]);
                Integer op_id = CodegenUtils.getCUDAopID(parts[2]);
                Class<?> cla = CodegenUtils.getClass(parts[2]);
                SpoofOperator fallback_java_op = 
CodegenUtils.createInstance(cla);
@@ -141,11 +132,9 @@ public class SpoofCUDAInstruction extends GPUInstruction {
                                ScalarObject out = _op.execute(ec, inputs, 
scalars);
                                ec.setScalarOutput(_out.getName(), out);
                        }
-                       
-                       _op.releaseScalarGPUMemory(ec);
                }
                catch(Exception ex) {
-                       LOG.error("SpoofCUDAInstruction: " + _op.getName() + " 
operator failed to execute. Trying Java fallback.(ToDo)\n");
+                       LOG.error("SpoofCUDAInstruction: " + _op.getName() + " 
operator failed to execute :(\n");
                        
                        throw new DMLRuntimeException(ex);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
index efbf201..633e3fc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
@@ -1168,4 +1168,5 @@ public class GPUObject {
        
        public static long getPointerAddress(Pointer p) {
                return (p == null) ?  0 : getPointerAddressInternal(p);
-       }}
+       }
+}

Reply via email to