This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit fc5b03de84ebd57214a69ca43f63f223dd258c89
Author: Mark Dokter <[email protected]>
AuthorDate: Wed Apr 20 13:06:05 2022 +0200

    [SYSTEMDS-3352] CUDA code gen support for connected components
    
    General cleanup and bug fixes to make components.dml run.
    Also improves handling of single-precision execution.
---
 src/main/cuda/headers/Matrix.h                     |  52 ++-
 src/main/cuda/headers/spoof_utils.cuh              | 215 +++++------
 src/main/cuda/headers/vector_write.cuh             |  20 +-
 src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp  |  23 +-
 src/main/cuda/spoof-launcher/SpoofRowwise.h        |   8 +-
 src/main/cuda/spoof-launcher/jni_bridge.cpp        |   4 +-
 src/main/cuda/spoof/rowwise.cu                     |   9 +-
 .../apache/sysds/hops/codegen/cplan/CNodeRow.java  |   7 +-
 .../sysds/hops/codegen/cplan/cuda/Binary.java      | 391 ++++++++-------------
 .../sysds/hops/codegen/cplan/cuda/Ternary.java     |  88 ++---
 .../sysds/hops/codegen/cplan/cuda/Unary.java       | 313 +++++++----------
 .../sysds/runtime/codegen/SpoofCUDACellwise.java   |   4 +-
 .../sysds/runtime/codegen/SpoofCUDARowwise.java    |   4 +-
 13 files changed, 462 insertions(+), 676 deletions(-)

diff --git a/src/main/cuda/headers/Matrix.h b/src/main/cuda/headers/Matrix.h
index 61ef939b83..f02a76c83c 100644
--- a/src/main/cuda/headers/Matrix.h
+++ b/src/main/cuda/headers/Matrix.h
@@ -18,8 +18,6 @@
  */
 
 #pragma once
-#ifndef SYSTEMDS_MATRIX_H
-#define SYSTEMDS_MATRIX_H
 
 using uint32_t = unsigned int;
 using int32_t = int;
@@ -43,22 +41,22 @@ struct Matrix {
 
 #ifdef __CUDACC__
 
-template<typename T>
-uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) {
-       upper -= 1;
-       while(lower <= (upper-1)) {
-               uint32_t idx = (lower + upper) >> 1;
-               uint32_t vi = values[idx];
-               if (vi < val)
-                       lower = idx + 1;
-               else {
-                       if (vi <= val)
-                               return idx;
-                       upper = idx - 1;
-               }
-       }
-       return upper + 1;
-}
+//template<typename T>
+//uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) {
+//     upper -= 1;
+//     while(lower <= (upper-1)) {
+//             uint32_t idx = (lower + upper) >> 1;
+//             uint32_t vi = values[idx];
+//             if (vi < val)
+//                     lower = idx + 1;
+//             else {
+//                     if (vi <= val)
+//                             return idx;
+//                     upper = idx - 1;
+//             }
+//     }
+//     return upper + 1;
+//}
 
 template<typename T>
 class MatrixAccessor {
@@ -68,11 +66,11 @@ class MatrixAccessor {
 public:
        MatrixAccessor() = default;
        
-       __device__ MatrixAccessor(Matrix<T>* mat) : _mat(mat) {}
+       __device__ explicit MatrixAccessor(Matrix<T>* mat) : _mat(mat) {}
        
        __device__ void init(Matrix<T>* mat) { _mat = mat; }
        
-       __device__ uint32_t& nnz() { return return _mat->row_ptr == nullptr ? 
_mat->rows * _mat->cols : _mat->nnz; }
+//     __device__ uint32_t& nnz() { return _mat->row_ptr == nullptr ? 
_mat->rows * _mat->cols : _mat->nnz; }
        __device__ uint32_t cols() { return _mat->cols; }
        __device__ uint32_t rows() { return _mat->rows; }
        
@@ -96,14 +94,14 @@ public:
        }
        
        __device__ uint32_t row_len(uint32_t rix) {
-               return _mat->row_ptr == nullptr ? row_len_dense(rix) : 
row_len_sparse(rix);
+               return _mat->row_ptr == nullptr ? _mat->cols : 
row_len_sparse(rix);
        }
        
        __device__ uint32_t* col_idxs(uint32_t rix) { return cols_sparse(rix); }
 
        __device__ void set(uint32_t r, uint32_t c, T v) { set_sparse(r,c,v); }
        
-       __device__ uint32_t* indexes() {  return _mat->row_ptr; }
+//     __device__ uint32_t* indexes() {  return _mat->row_ptr; }
        
        __device__ bool hasData() { return _mat->data != nullptr; }
 private:
@@ -127,10 +125,6 @@ private:
                return &(_mat->data[rix]);
        }
        
-       __device__ uint32_t row_len_dense(uint32_t rix) {
-               return _mat->rows;
-       }
-       
        //ToDo sparse accessors
        __device__ uint32_t len_sparse() {
                return _mat->row_ptr[_mat->rows];
@@ -145,8 +139,8 @@ private:
        }
        
        __device__ T& val_sparse_rc(uint32_t r, uint32_t c) {
-//             printf("TBI: val_sparse_rc\n");
-//             asm("trap;");
+               printf("TBI: val_sparse_rc(%u, %u)\n", r, c);
+               asm("trap;");
 
                return _mat->data[0];
        }
@@ -228,5 +222,3 @@ public:
 };
 
 #endif // __CUDACC_RTC__
-
-#endif //SYSTEMDS_MATRIX_H
diff --git a/src/main/cuda/headers/spoof_utils.cuh 
b/src/main/cuda/headers/spoof_utils.cuh
index 5d9b1012b2..8ab0fafdb2 100644
--- a/src/main/cuda/headers/spoof_utils.cuh
+++ b/src/main/cuda/headers/spoof_utils.cuh
@@ -18,8 +18,6 @@
  */
 
 #pragma once
-#ifndef SPOOF_UTILS_CUH
-#define SPOOF_UTILS_CUH
 
 #include <math_constants.h>
 #include "vector_add.cuh"
@@ -31,13 +29,8 @@ struct TempStorage;
 #include "Matrix.h"
 #include "vector_write.cuh"
 
-// #include "intellisense_cuda_intrinsics.h"
-
 using uint32_t = unsigned int;
 
-//static __device__  bool debug_row() { return blockIdx.x == 0; };
-//static __device__ bool debug_thread() { return threadIdx.x == 0; }
-
 __constant__ double DOUBLE_EPS = 1.11022E-16; // 2 ^ -53
 __constant__ double FLOAT_EPS = 1.49012E-08; // 2 ^ -26
 __constant__ double EPSILON = 1E-11; // margin for comparisons ToDo: make 
consistent use of it
@@ -79,12 +72,6 @@ __device__ Vector<T>& getVector(MatrixAccessor<T>& data, 
uint32_t n, uint32_t ri
                c[i] = data.val(rix, i);
                i += blockDim.x;
        }
-//     if(debug_thread()) {
-//             printf("getVector: c.len=%d rix=%d\n", c.length, rix);
-//             for(auto j = 0; j < c.length; ++j)
-//                     printf("%4.3f ", c[j]);
-//             printf("\n");
-//     }
        return c;
 }
 
@@ -146,122 +133,147 @@ __device__ T BLOCK_ROW_AGG(T *a, T *b, uint32_t len, 
AggOp agg_op, LoadOp load_o
        auto sdata = shared_memory_proxy<T>();
        uint tid = threadIdx.x;
-
-       // Initalize shared mem and leave if tid > row length. 
-//     if(tid >= len) { return sdata[tid] = AggOp::init();; }
-
        __syncthreads();
-       
-//                      if(blockIdx.x == 0 && threadIdx.x == 0)
-//                printf("tid=%d sdata[tid + 128]=%f, len=%d\n", tid, len, 
sdata[tid+128]);
        uint i = tid;
        T v = AggOp::init();
-//                      if(blockIdx.x == 0 && threadIdx.x == 0)
-//                printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
-       while (i < len) {
+       while(i < len) {
                v = agg_op(v, load_op(a[i], b[i]));
                i += blockDim.x;
        }
 
-//              if(blockIdx.x == 0 && threadIdx.x == 0)
-//     if(debug_row() && debug_thread())
-//                printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
-       
        // each thread puts its local sum into shared memory
        sdata[tid] = v;
-       // if(blockIdx.x==0)
-               // printf("tid=%d v=%f, len=%d\n", tid, v, len);
        __syncthreads();
 
-                       // if(blockIdx.x == 0 && threadIdx.x == 0)
-                //  printf("tid=%d sdata[tid + 128]=%f\n", tid, 
sdata[tid+128]);
-       
        // do reduction in shared mem
-       if (blockDim.x >= 1024) {
-               if (tid < 512 && (tid+512) < len) {
-                               // if(blockIdx.x == 0 && threadIdx.x == 0)
-                 // printf("tid=%d sdata[tid + 512]=%f\n", tid, 
sdata[tid+512]);
+       if(blockDim.x >= 1024) {
+               if(tid < 512 && (tid+512) < len) {
                        sdata[tid] = v = agg_op(v, sdata[tid + 512]);
                }
                __syncthreads();
        }
-       if (blockDim.x >= 512) {
-               if (tid < 256 && (tid+256) < len) {
-                               // if(blockIdx.x == 0 && threadIdx.x == 0)
-                 // printf("tid=%d sdata[tid + 256]=%f\n", tid, 
sdata[tid+256]);
+       if(blockDim.x >= 512) {
+               if(tid < 256 && (tid+256) < len) {
                        sdata[tid] = v = agg_op(v, sdata[tid + 256]);
                }
                __syncthreads();
        }
-       if (blockDim.x >= 256) {
-               if (tid < 128 && (tid+128) < len) {
-                               // if(blockIdx.x == 0 && threadIdx.x == 0)
-                 // printf("tid=%d sdata[tid + 128]=%f\n", tid, 
sdata[tid+128]);
+       if(blockDim.x >= 256) {
+               if(tid < 128 && (tid+128) < len) {
                        sdata[tid] = v = agg_op(v, sdata[tid + 128]);
                }
                __syncthreads();
        }
-       if (blockDim.x >= 128) {
-               if (tid < 64 && (tid+64) < len) {
-                               // if(blockIdx.x == 0 && threadIdx.x == 0)
-                 // printf("tid=%d sdata[tid + 64]=%f\n", tid, sdata[tid+64]);
+       if(blockDim.x >= 128) {
+               if(tid < 64 && (tid+64) < len) {
                        sdata[tid] = v = agg_op(v, sdata[tid + 64]);
                }
                __syncthreads();
        }
- 
-       if (tid < 32) {
+
+       if(tid < 32) {
                // now that we are using warp-synchronous programming (below)
                // we need to declare our shared memory volatile so that the 
compiler
                // doesn't reorder stores to it and induce incorrect behavior.
                volatile T *smem = sdata;
-               if (blockDim.x >= 64 && (tid+32) < len) {
+               if(blockDim.x >= 64 && (tid+32) < len) {
                        smem[tid] = v = agg_op(v, smem[tid + 32]);
                }
-               // if(blockIdx.x==0)
-                 // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-               if (blockDim.x >= 32 && (tid+16) < len) {
+               if(blockDim.x >= 32 && (tid+16) < len) {
                        smem[tid] = v = agg_op(v, smem[tid + 16]);
                }
-               // if(blockIdx.x==0)
-                 // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-               if (blockDim.x >= 16 && (tid+8) < len) {
+               if(blockDim.x >= 16 && (tid+8) < len) {
                        smem[tid] = v = agg_op(v, smem[tid + 8]);
                }
-               // if(blockIdx.x==0)
-                 // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-               if (blockDim.x >= 8 && (tid+4) < len) {
+               if(blockDim.x >= 8 && (tid+4) < len) {
                        smem[tid] = v = agg_op(v, smem[tid + 4]);
                }
-               // if(blockIdx.x==0 && threadIdx.x ==0)
-                 // printf("tid=%d smem[tid + 4]=%f\n", tid, smem[tid+4]);
-               if (blockDim.x >= 4 && (tid+2) < len) {
+               if(blockDim.x >= 4 && (tid+2) < len) {
                        smem[tid] = v = agg_op(v, smem[tid + 2]);
                }
-               // if(blockIdx.x==0 && threadIdx.x ==0)
-                 // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-               if (blockDim.x >= 2 && (tid+1) < len) {
-               // if (blockDim.x >= 2) {
+               if(blockDim.x >= 2 && (tid+1) < len) {
                        smem[tid] = v = agg_op(v, smem[tid + 1]);
                }
-//              if(blockIdx.x==0 && threadIdx.x ==0)
-//             if(debug_row() && debug_thread())
-//                printf("tid=%d smem[0]=%f\n", tid, smem[0]);
        }
-
        __syncthreads();
        return sdata[0];
 }
 
+
+template<typename T, typename AggOp, typename LoadOp>
+__device__ T BLOCK_ROW_AGG(T *a, T *b, uint32_t* aix, uint32_t len, AggOp 
agg_op, LoadOp load_op) {
+    auto sdata = shared_memory_proxy<T>();
+    uint tid = threadIdx.x;
+    __syncthreads(); // a preceding BLOCK_ROW_AGG call may still be reading sdata[0]; wait before overwriting it
+    uint i = tid;
+    T v = AggOp::init();
+    while(i < len) {
+        v = agg_op(v, load_op(a[i], b[aix[i]]));
+        i += blockDim.x;
+    }
+
+    // each thread puts its local sum into shared memory
+    sdata[tid] = v;
+    __syncthreads();
+
+    // do reduction in shared mem
+    if(blockDim.x >= 1024) {
+        if(tid < 512 && (tid+512) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 512]);
+        }
+        __syncthreads();
+    }
+    if(blockDim.x >= 512) {
+        if(tid < 256 && (tid+256) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 256]);
+        }
+        __syncthreads();
+    }
+    if(blockDim.x >= 256) {
+        if(tid < 128 && (tid+128) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 128]);
+        }
+        __syncthreads();
+    }
+    if(blockDim.x >= 128) {
+        if(tid < 64 && (tid+64) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 64]);
+        }
+        __syncthreads();
+    }
+
+    if(tid < 32) {
+        // now that we are using warp-synchronous programming (below)
+        // we need to declare our shared memory volatile so that the compiler
+        // doesn't reorder stores to it and induce incorrect behavior.
+        volatile T *smem = sdata;
+        if(blockDim.x >= 64 && (tid+32) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 32]);
+        }
+        if(blockDim.x >= 32 && (tid+16) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 16]);
+        }
+        if(blockDim.x >= 16 && (tid+8) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 8]);
+        }
+        if(blockDim.x >= 8 && (tid+4) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 4]);
+        }
+        if(blockDim.x >= 4 && (tid+2) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 2]);
+        }
+        if(blockDim.x >= 2 && (tid+1) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 1]);
+        }
+    }
+    __syncthreads();
+    return sdata[0];
+}
+
 template<typename T>
 __device__ T dotProduct(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t len) {
        SumOp<T> agg_op;
        ProductOp<T> load_op;
-//     if(debug_row() && debug_thread())
-//             printf("dot len = %d\n", len);
-       T ret =  BLOCK_ROW_AGG(&a[ai], &b[bi], len, agg_op, load_op);
-//     if(debug_row() && debug_thread())
-//             printf("bid=%d, ai=%d, dot=%f\n", blockIdx.x, ai, ret);
-       return ret;
+       return BLOCK_ROW_AGG(&a[ai], &b[bi], len, agg_op, load_op);
 }
 
 template<typename T>
@@ -277,8 +289,6 @@ __device__ T vectSum(T* a, uint32_t ai, uint32_t len) {
        SumOp<T> agg_op;
        IdentityOp<T> load_op;
        T result = BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
-//     if(debug_row() && debug_thread())
-//             printf("vectSum: bid=%d, tid=%d ai=%d len=%d result=%4.3f\n", 
blockIdx.x, threadIdx.x, ai, len, result);
        return result;
 }
 
@@ -286,10 +296,22 @@ template<typename T>
 __device__ T vectMin(T* a, int ai, int len) {
        MinOp<T> agg_op;
        IdentityOp<T> load_op;
-       T result = BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
-//     if(debug_row() && debug_thread())
-//             printf("vectMin: bid=%d, tid=%d ai=%d len=%d result=%4.3f\n", 
blockIdx.x, threadIdx.x, ai, len, result);
-       return result;
+       return BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
+}
+
+template<typename T>
+__device__ T rowMaxsVectMult(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t 
len) {
+    MaxOp<T> agg_op;
+    ProductOp<T> load_op;
+    return BLOCK_ROW_AGG(&a[ai], &b[0], len, agg_op, load_op);
+}
+
+template<typename T>
+__device__ T rowMaxsVectMult(T* a, T* b, uint32_t* aix, uint32_t ai, uint32_t 
bi, uint32_t len) {
+    MaxOp<T> agg_op;
+    ProductOp<T> load_op;
+
+    return BLOCK_ROW_AGG(&a[ai], &b[0], &aix[ai], len, agg_op, load_op);
 }
 
 template<typename T>
@@ -302,25 +324,7 @@ __device__ T vectMax(T* a, uint32_t ai, uint32_t len) {
 
 template<typename T>
 __device__ T vectMax(T* avals, uint32_t* aix, uint32_t ai, uint32_t alen, 
uint32_t len) {
-//     if (debug_row() && debug_thread()) {
-//             printf("\naix[i]:\n");
-//             for(auto i = 0; i < alen; ++i)
-//                     printf(" %d", aix[i]);
-               
-//             printf("\navals[i]:\n");
-//             for(auto i = 0; i < alen; ++i)
-//                     printf(" %4.3f", avals[i]);
-               
-//             printf("\navals[aix[i]]:\n");
-//             for(auto i = 0; i < alen; ++i)
-//                     printf(" %4.3f", avals[aix[i]]);
-
-//             printf("\n");
-//     }
-
        T result = vectMax(avals, ai, alen);
-//     if (blockIdx.x < 5 && debug_thread())
-//             printf("bid=%d, tid=%d, len=%d, alen=%d, ai=%d 
vectMax=%4.3f\n", blockIdx.x, threadIdx.x, len, alen, ai, result);
        return alen < len ? MaxOp<T>::exec(result, 0.0) : result;
 }
 
@@ -547,6 +551,12 @@ Vector<T>& vectMultWrite(T* a, T* b, uint32_t ai, uint32_t 
bi, uint32_t len, Tem
        return vectWriteBinary<T, ProductOp<T>>(a, b, ai, bi, len, fop, "Mult");
 }
 
+// sparse-dense MxV
+template<typename T>
+Vector<T>& vectMultWrite(T* avals, T* b, uint32_t* aix, uint32_t ai, uint32_t 
bi, uint32_t alen, uint32_t len, TempStorage<T>* fop) {
+    return vectWriteBinary<T, ProductOp<T>>(avals, b, aix, ai, bi, alen, len, 
fop, "Mult");
+}
+
 template<typename T>
 Vector<T>& vectDivWrite(T* a, T b, uint32_t ai, uint32_t len, TempStorage<T>* 
fop) {
        return vectWriteBinary<T, DivOp<T>>(a, b, ai, len, fop, "Div");
@@ -744,6 +754,3 @@ void vectOuterMultAdd(T* a, T* b, T* c, uint32_t ai, 
uint32_t bi, uint32_t ci, u
                i += blockDim.x;
        }
 }
-
-
-#endif // SPOOF_UTILS_CUH
diff --git a/src/main/cuda/headers/vector_write.cuh 
b/src/main/cuda/headers/vector_write.cuh
index 3099926167..55241bd8d4 100644
--- a/src/main/cuda/headers/vector_write.cuh
+++ b/src/main/cuda/headers/vector_write.cuh
@@ -18,10 +18,9 @@
  */
 
 #pragma once
-#ifndef SYSTEMDS_VECTOR_WRITE_CUH
-#define SYSTEMDS_VECTOR_WRITE_CUH
 
-__device__ bool debug_row() { return blockIdx.x == 1; };
+#define DEBUG_ROW 2
+__device__ bool debug_row() { return blockIdx.x == DEBUG_ROW; };
 __device__ bool debug_thread() { return threadIdx.x == 0; }
 
 // unary transform vector by OP and write to intermediate vector
@@ -143,6 +142,19 @@ __device__ Vector<T>& vectWriteBinary(T* a, T* b, uint32_t 
ai, uint32_t bi, uint
        return c;
 }
 
+// sparse-dense binary vect-vect to intermediate vector; writes only positions aix[ai .. ai+alen-1] of c (other entries keep what getTempStorage returned - TODO confirm it zeroes)
+template<typename T, typename OP>
+__device__ Vector<T>& vectWriteBinary(T* a, T* b, uint32_t* aix, uint32_t ai, 
uint32_t bi, uint32_t alen, uint32_t len,
+        TempStorage<T>* fop, const char* name = nullptr) {
+    uint32_t i = threadIdx.x;
+    Vector<T>& c = fop->getTempStorage(len);
+    while (i < alen) {
+        c[aix[ai + i]] = OP::exec(a[ai + i], b[bi + aix[ai + i]]);
+        i += blockDim.x;
+    }
+    return c;
+}
+
 // binary vector-scalar to output vector c
 template<typename T, typename OP>
 __device__ void vectWriteBinary(T* a, T b, T* c, uint32_t ai, uint32_t ci, 
uint32_t len) {
@@ -168,5 +180,3 @@ __device__ void vectWriteBinary(T* a, T* b, T* c, uint32_t 
ai, uint32_t bi, uint
                i += blockDim.x;
        }
 }
-
-#endif //SYSTEMDS_VECTOR_WRITE_CUH
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp 
b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
index 2ef482d18d..c4ae1e3dff 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
@@ -47,11 +47,7 @@ size_t SpoofCUDAContext::initialize_cuda(uint32_t device_id, 
const char* resourc
        s1 << "-I" << resource_path << "/cuda/headers";
        s2 << "-I" << resource_path << "/cuda/spoof";
        auto ctx = new SpoofCUDAContext(resource_path,{s1.str(), s2.str(), 
cuda_include_path});
-       // cuda device is handled by jCuda atm
-       //cudaSetDevice(device_id);
-       //cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
-       //cudaDeviceSynchronize();
-       
+
        CHECK_CUDA(cuModuleLoad(&(ctx->reductions), 
std::string(ctx->resource_path + 
std::string("/cuda/kernels/reduction.ptx")).c_str()));
        
        CUfunction func;
@@ -87,30 +83,13 @@ void SpoofCUDAContext::destroy_cuda(SpoofCUDAContext *ctx, 
[[maybe_unused]] uint
        cudaFreeHost(ctx->staging_buffer);
        cudaFree(ctx->device_buffer);
        delete ctx;
-       // cuda device is handled by jCuda atm
-       //cudaDeviceReset();
 }
 
 size_t SpoofCUDAContext::compile(std::unique_ptr<SpoofOperator> op, const 
std::string &src) {
-#ifndef NDEBUG
-//     std::cout << "---=== START source listing of spoof cuda kernel [ " << 
name << " ]: " << std::endl;
-//    uint32_t line_num = 0;
-//     std::istringstream src_stream(src);
-//    for(std::string line; std::getline(src_stream, line); line_num++)
-//             std::cout << line_num << ": " << line << std::endl;
-//     std::cout << "---=== END source listing of spoof cuda kernel [ " << 
name << " ]." << std::endl;
-       std::cout << "cwd: " << std::filesystem::current_path() << std::endl;
-       std::cout << "include_paths: ";
-       for_each (include_paths.begin(), include_paths.end(), [](const 
std::string& line){ std::cout << line << '\n';});
-       std::cout << std::endl;
-#endif
-
-// uncomment all related lines for temporary timing output:
 //     auto compile_start = clk::now();
        op->program = 
std::make_unique<jitify::Program>(kernel_cache.program(src, 0, include_paths));
 //     auto compile_end = clk::now();
 //     auto compile_duration = std::chrono::duration_cast<sec>(compile_end - 
compile_start).count();
-
        compiled_ops.push_back(std::move(op));
 //     compile_total += compile_duration;
 //     std::cout << name << " compiled in "
diff --git a/src/main/cuda/spoof-launcher/SpoofRowwise.h 
b/src/main/cuda/spoof-launcher/SpoofRowwise.h
index 4465ac99fa..01ec5206aa 100644
--- a/src/main/cuda/spoof-launcher/SpoofRowwise.h
+++ b/src/main/cuda/spoof-launcher/SpoofRowwise.h
@@ -18,15 +18,13 @@
  */
 
 #pragma once
-#ifndef SYSTEMDS_SPOOFROWWISE_H
-#define SYSTEMDS_SPOOFROWWISE_H
 
 #include "SpoofCUDAContext.h"
 #include <algorithm>
 
 template <typename T>
 struct SpoofRowwise {
-       
+
        static void exec([[maybe_unused]] SpoofCUDAContext* ctx, SpoofOperator* 
_op, DataBufferWrapper* dbw)  {
                uint32_t NT=256;
                T value_type;
@@ -56,7 +54,7 @@ struct SpoofRowwise {
                        
CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&d_temp), temp_buf_size));
                        CHECK_CUDART(cudaMemsetAsync(d_temp, 0, temp_buf_size, 
op->stream));
                }
-               
+
                std::string op_name(op->name + "_DENSE");
                if(sparse_input)
                        op_name = std::string(op->name + "_SPARSE");
@@ -77,5 +75,3 @@ struct SpoofRowwise {
                        CHECK_CUDART(cudaFree(d_temp));
        }
 };
-
-#endif //SYSTEMDS_SPOOFROWWISE_H
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.cpp 
b/src/main/cuda/spoof-launcher/jni_bridge.cpp
index 5134d5e292..65f4a5a19f 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.cpp
+++ b/src/main/cuda/spoof-launcher/jni_bridge.cpp
@@ -165,7 +165,7 @@ int launch_spoof_operator([[maybe_unused]] JNIEnv *jenv, 
[[maybe_unused]] jclass
 
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f
        (JNIEnv *jenv, jclass jobj, jlong ctx) {
-       return launch_spoof_operator<double, SpoofCellwise<double>>(jenv, jobj, 
ctx);
+       return launch_spoof_operator<float, SpoofCellwise<float>>(jenv, jobj, 
ctx);
 }
 
 
@@ -177,5 +177,5 @@ int launch_spoof_operator([[maybe_unused]] JNIEnv *jenv, 
[[maybe_unused]] jclass
 
 [[maybe_unused]] JNIEXPORT jint JNICALL 
Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f
        (JNIEnv *jenv, jclass jobj, jlong ctx) {
-       return launch_spoof_operator<double, SpoofRowwise<double>>(jenv, jobj, 
ctx);
+       return launch_spoof_operator<float, SpoofRowwise<float>>(jenv, jobj, 
ctx);
 }
\ No newline at end of file
diff --git a/src/main/cuda/spoof/rowwise.cu b/src/main/cuda/spoof/rowwise.cu
index b31ce0c2ce..917b8e7087 100644
--- a/src/main/cuda/spoof/rowwise.cu
+++ b/src/main/cuda/spoof/rowwise.cu
@@ -48,15 +48,14 @@ struct SpoofRowwiseOp //%HAS_TEMP_VECT%
                a.init(A);
                c.init(C);
                
-               if(B)
-                       for(auto i = 0; i < NUM_B; ++i)
-                               b[i].init(&(B[i]));
+               if(B) {
+                   for(auto i = 0; i < NUM_B; ++i)
+                       b[i].init(&(B[i]));
+               }
        }
 
        __device__  __forceinline__ void exec_dense(uint32_t ai, uint32_t ci, 
uint32_t rix) {
 //%BODY_dense%
-               if (debug_row() && debug_thread())
-                       printf("c[0]=%4.3f\n", c.vals(0)[0]);
        }
 
        __device__  __forceinline__ void exec_sparse(uint32_t ai, uint32_t ci, 
uint32_t rix) {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
index 88844d6fbb..13c8e10fca 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
@@ -56,9 +56,12 @@ public class CNodeRow extends CNodeTpl
        private static final String TEMPLATE_NOAGG_OUT   = "    
LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n";
        private static final String TEMPLATE_NOAGG_CONST_OUT_CUDA   = 
"\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
        private static final String TEMPLATE_NOAGG_OUT_CUDA   = 
"\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
-       private static final String TEMPLATE_ROWAGG_OUT_CUDA  = 
"\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d 
TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
+//     private static final String TEMPLATE_ROWAGG_OUT_CUDA  = 
"\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d 
TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
+       private static final String TEMPLATE_ROWAGG_OUT_CUDA  = "\t\tif(threadIdx.x == 
0){\n\t\t\t*(c.vals(rix)) = %IN%;\n\t\t}\n";
+//     private static final String TEMPLATE_FULLAGG_OUT_CUDA =
+//             "\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), 
%IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, 
old);\n\t\t}\n";
        private static final String TEMPLATE_FULLAGG_OUT_CUDA =
-               "\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), 
%IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, 
old);\n\t\t}\n";
+               "\t\tif(threadIdx.x == 0) {\n\t\t\tatomicAdd(c.vals(0), 
%IN%);\n\t\t}\n";
 
 
        public CNodeRow(ArrayList<CNode> inputs, CNode output ) {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
index 6d826b16b4..ec46e1196c 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
@@ -42,267 +42,156 @@ public class Binary extends CodeTemplate
                                        "\t\tVector<T>& %TMP% = 
vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN1%, %LEN2%, 
this);\n" :
                                        "\t\tVector<T>& %TMP% = 
vectCbindWrite(%IN1%, %IN2%, %POS1%, %POS2%, %LEN1%, %LEN2%, this);\n";
                }
-               
-               if(isSinglePrecision()) {
-                       switch(type) {
-                               case DOT_PRODUCT:
-                                       return sparseLhs ? "    T %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" 
: "    T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, 
%LEN%);\n";
-                               case VECT_MATRIXMULT:
-                                       return sparseLhs ? "    T[] %TMP% = 
LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, 
len);\n" : " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, 
%POS1%, %POS2%, %LEN%);\n";
-                               case VECT_OUTERMULT_ADD:
-                                       return sparseLhs ? "    
LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, 
%POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? "   
LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, 
%POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "       
LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, 
%POSOUT%, %LEN1%, %LEN2%);\n";
 
-                               //vector-scalar-add operations
-                               case VECT_MULT_ADD:
-                               case VECT_DIV_ADD:
-                               case VECT_MINUS_ADD:
-                               case VECT_PLUS_ADD:
-                               case VECT_POW_ADD:
-                               case VECT_XOR_ADD:
-                               case VECT_MIN_ADD:
-                               case VECT_MAX_ADD:
-                               case VECT_EQUAL_ADD:
-                               case VECT_NOTEQUAL_ADD:
-                               case VECT_LESS_ADD:
-                               case VECT_LESSEQUAL_ADD:
-                               case VECT_GREATER_ADD:
-                               case VECT_GREATEREQUAL_ADD:
-                               case VECT_CBIND_ADD: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       if(scalarVector)
-                                               return sparseLhs ? "    
LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, 
%POS2%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + 
"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
-                                       else
-                                               return sparseLhs ? "    
LibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, 
%POS1%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + 
"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
-                               }
+               switch(type) {
+                       case ROWMAXS_VECTMULT:
+                               return sparseLhs ? "\t\tT %TMP% = 
rowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
+                                               "\t\tT %TMP% = 
rowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+                       case DOT_PRODUCT:
+                               return sparseLhs ? "\t\tT %TMP% = 
dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "               T 
%TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), 
%LEN%);\n";
 
-                               //vector-scalar operations
-                               case VECT_MULT_SCALAR:
-                               case VECT_DIV_SCALAR:
-                               case VECT_MINUS_SCALAR:
-                               case VECT_PLUS_SCALAR:
-                               case VECT_POW_SCALAR:
-                               case VECT_XOR_SCALAR:
-                               case VECT_BITWAND_SCALAR:
-                               case VECT_MIN_SCALAR:
-                               case VECT_MAX_SCALAR:
-                               case VECT_EQUAL_SCALAR:
-                               case VECT_NOTEQUAL_SCALAR:
-                               case VECT_LESS_SCALAR:
-                               case VECT_LESSEQUAL_SCALAR:
-                               case VECT_GREATER_SCALAR:
-                               case VECT_GREATEREQUAL_SCALAR: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       if(scalarVector)
-                                               return sparseRhs ? "    T[] 
%TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, 
%POS2%, alen, %LEN%);\n" : "    T[] %TMP% = LibSpoofPrimitives.vect" + vectName 
+ "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
-                                       else
-                                               return sparseLhs ? "    T[] 
%TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, 
%POS1%, alen, %LEN%);\n" : "    T[] %TMP% = LibSpoofPrimitives.vect" + vectName 
+ "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
-                               }
-                               //vector-vector operations
-                               case VECT_MULT:
-                               case VECT_DIV:
-                               case VECT_MINUS:
-                               case VECT_PLUS:
-                               case VECT_XOR:
-                               case VECT_BITWAND:
-                               case VECT_BIASADD:
-                               case VECT_BIASMULT:
-                               case VECT_MIN:
-                               case VECT_MAX:
-                               case VECT_EQUAL:
-                               case VECT_NOTEQUAL:
-                               case VECT_LESS:
-                               case VECT_LESSEQUAL:
-                               case VECT_GREATER:
-                               case VECT_GREATEREQUAL: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       return sparseLhs ? "    T[] %TMP% = 
LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, 
%POS2%, alen, %LEN%);\n" : sparseRhs ? "        T[] %TMP% = 
LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, 
%POS2%, alen, %LEN%);\n" : "    T[] %TMP% = LibSpoofPrimitives.vect" + vectName 
+ "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
-                               }
+                       case VECT_MATRIXMULT:
+                               return sparseLhs ? "    T[] %TMP% = 
vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : "        
    Vector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, 
static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
+                       case VECT_OUTERMULT_ADD:
+                               return sparseLhs ? "    
LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, 
%POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? "   
LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, 
%POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "\t\tvectOuterMultAdd(%IN1%, 
%IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
 
-                               //scalar-scalar operations
-                               case MULT:
-                                       return "        T %TMP% = %IN1% * 
%IN2%;\n";
-                               case DIV:
-                                       return "        T %TMP% = %IN1% / 
%IN2%;\n";
-                               case PLUS:
-                                       return "        T %TMP% = %IN1% + 
%IN2%;\n";
-                               case MINUS:
-                                       return "        T %TMP% = %IN1% - 
%IN2%;\n";
-                               case MODULUS:
-                                       return "        T %TMP% = 
modulus(%IN1%, %IN2%);\n";
-                               case INTDIV:
-                                       return "        T %TMP% = intDiv(%IN1%, 
%IN2%);\n";
-                               case LESS:
-                                       return "        T %TMP% = (%IN1% < 
%IN2%) ? 1 : 0;\n";
-                               case LESSEQUAL:
-                                       return "        T %TMP% = (%IN1% <= 
%IN2%) ? 1 : 0;\n";
-                               case GREATER:
-                                       return "        T %TMP% = (%IN1% > 
%IN2%) ? 1 : 0;\n";
-                               case GREATEREQUAL:
-                                       return "        T %TMP% = (%IN1% >= 
%IN2%) ? 1 : 0;\n";
-                               case EQUAL:
-                                       return "        T %TMP% = (%IN1% == 
%IN2%) ? 1 : 0;\n";
-                               case NOTEQUAL:
-                                       return "        T %TMP% = (%IN1% != 
%IN2%) ? 1 : 0;\n";
-
-                               case MIN:
-                                       return "        T %TMP% = fminf(%IN1%, 
%IN2%);\n";
-                               case MAX:
-                                       return "        T %TMP% = fmaxf(%IN1%, 
%IN2%);\n";
-                               case LOG:
-                                       return "        T %TMP% = 
logf(%IN1%)/Math.log(%IN2%);\n";
-                               case LOG_NZ:
-                                       return "        T %TMP% = (%IN1% == 0) 
? 0 : logf(%IN1%) / logf(%IN2%);\n";
-                               case POW:
-                                       return "        T %TMP% = powf(%IN1%, 
%IN2%);\n";
-                               case MINUS1_MULT:
-                                       return "        T %TMP% = 1 - %IN1% * 
%IN2%;\n";
-                               case MINUS_NZ:
-                                       return "        T %TMP% = (%IN1% != 0) 
? %IN1% - %IN2% : 0;\n";
-                               case XOR:
-                                       return "        T %TMP% = ( (%IN1% != 
0) != (%IN2% != 0) ) ? 1.0f : 0.0f;\n";
-                               case BITWAND:
-                                       return "        T %TMP% = bwAnd(%IN1%, 
%IN2%);\n";
-                               case SEQ_RIX:
-                                       return "        T %TMP% = %IN1% + grix 
* %IN2%;\n"; //0-based global rix
-
-                               default:
-                                       throw new RuntimeException("Invalid 
binary type: " + this.toString());
+                       //vector-scalar-add operations
+                       case VECT_MULT_ADD:
+                       case VECT_DIV_ADD:
+                       case VECT_MINUS_ADD:
+                       case VECT_PLUS_ADD:
+                       case VECT_POW_ADD:
+                       case VECT_XOR_ADD:
+                       case VECT_MIN_ADD:
+                       case VECT_MAX_ADD:
+                       case VECT_EQUAL_ADD:
+                       case VECT_NOTEQUAL_ADD:
+                       case VECT_LESS_ADD:
+                       case VECT_LESSEQUAL_ADD:
+                       case VECT_GREATER_ADD:
+                       case VECT_GREATEREQUAL_ADD:
+                       case VECT_CBIND_ADD: {
+                               String vectName = type.getVectorPrimitiveName();
+                               if(scalarVector)
+                                       return sparseLhs ? "\t\tvect" + 
vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, 
%LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, 
%POSOUT%, %LEN%);\n";
+                               else
+                                       return sparseLhs ? "\t\tvect" + 
vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, 
%LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, 
static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
                        }
-               }
-               else {
-                       switch(type) {
-                               case DOT_PRODUCT:
-//                                     return sparseLhs ? "    T %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" 
: "    T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, 
%LEN%);\n";
-//                                     return sparseLhs ? "            T %TMP% 
= dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "               
T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n    
printf(\"dot=%f, bid=%d, tid=%d\\n\",TMP7,blockIdx.x, threadIdx.x);\n   
__syncthreads();\n";
-                                       return sparseLhs ? "            T %TMP% 
= dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "               
T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), 
%LEN%);\n";
-                               
-                               case VECT_MATRIXMULT:
-                                       return sparseLhs ? "    T[] %TMP% = 
vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : "        
    Vector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, 
static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
-                               case VECT_OUTERMULT_ADD:
-                                       return sparseLhs ? "    
LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, 
%POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? "   
LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, 
%POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "\t\tvectOuterMultAdd(%IN1%, 
%IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
-
-                               //vector-scalar-add operations
-                               case VECT_MULT_ADD:
-                               case VECT_DIV_ADD:
-                               case VECT_MINUS_ADD:
-                               case VECT_PLUS_ADD:
-                               case VECT_POW_ADD:
-                               case VECT_XOR_ADD:
-                               case VECT_MIN_ADD:
-                               case VECT_MAX_ADD:
-                               case VECT_EQUAL_ADD:
-                               case VECT_NOTEQUAL_ADD:
-                               case VECT_LESS_ADD:
-                               case VECT_LESSEQUAL_ADD:
-                               case VECT_GREATER_ADD:
-                               case VECT_GREATEREQUAL_ADD:
-                               case VECT_CBIND_ADD: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       if(scalarVector)
-                                               return sparseLhs ? "\t\tvect" + 
vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, 
%LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, 
%POSOUT%, %LEN%);\n";
-                                       else
-                                               return sparseLhs ? "\t\tvect" + 
vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, 
%LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, 
static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
-                               }
 
-                               //vector-scalar operations
-                               case VECT_MULT_SCALAR:
-                               case VECT_DIV_SCALAR:
-                               case VECT_MINUS_SCALAR:
-                               case VECT_PLUS_SCALAR:
-                               case VECT_POW_SCALAR:
-                               case VECT_XOR_SCALAR:
-                               case VECT_BITWAND_SCALAR:
-                               case VECT_MIN_SCALAR:
-                               case VECT_MAX_SCALAR:
-                               case VECT_EQUAL_SCALAR:
-                               case VECT_NOTEQUAL_SCALAR:
-                               case VECT_LESS_SCALAR:
-                               case VECT_LESSEQUAL_SCALAR:
-                               case VECT_GREATER_SCALAR:
-                               case VECT_GREATEREQUAL_SCALAR: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       if(scalarVector)
-                                               return sparseRhs ? "    T[] 
%TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, 
%POS2%, alen, %LEN%);\n" : "            Vector<T>& %TMP% = vect" + vectName + 
"Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
-                                       else
+                       //vector-scalar operations
+                       case VECT_MULT_SCALAR:
+                       case VECT_DIV_SCALAR:
+                       case VECT_MINUS_SCALAR:
+                       case VECT_PLUS_SCALAR:
+                       case VECT_POW_SCALAR:
+                       case VECT_XOR_SCALAR:
+                       case VECT_BITWAND_SCALAR:
+                       case VECT_MIN_SCALAR:
+                       case VECT_MAX_SCALAR:
+                       case VECT_EQUAL_SCALAR:
+                       case VECT_NOTEQUAL_SCALAR:
+                       case VECT_LESS_SCALAR:
+                       case VECT_LESSEQUAL_SCALAR:
+                       case VECT_GREATER_SCALAR:
+                       case VECT_GREATEREQUAL_SCALAR: {
+                               String vectName = type.getVectorPrimitiveName();
+                               if(scalarVector)
+                                       return sparseRhs ? "    T[] %TMP% = 
LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, 
alen, %LEN%);\n" : "            Vector<T>& %TMP% = vect" + vectName + 
"Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
+                               else
 //                                             return sparseLhs ? "    T[] 
%TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, 
%POS1%, alen, %LEN%);\n" : "    T[] %TMP% = LibSpoofPrimitives.vect" + vectName 
+ "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
-                                               return sparseLhs ? "            
Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, 
alen, %LEN%, this);\n" : "          Vector<T>& %TMP% = vect" + vectName + 
"Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
-                               }
-                                       
-                               //vector-vector operations
-                               case VECT_MULT:
-                               case VECT_DIV:
-                               case VECT_MINUS:
-                               case VECT_PLUS:
-                               case VECT_XOR:
-                               case VECT_BITWAND:
-                               case VECT_BIASADD:
-                               case VECT_BIASMULT:
-                               case VECT_MIN:
-                               case VECT_MAX:
-                               case VECT_EQUAL:
-                               case VECT_NOTEQUAL:
-                               case VECT_LESS:
-                               case VECT_LESSEQUAL:
-                               case VECT_GREATER:
-                               case VECT_GREATEREQUAL: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       return sparseLhs ? "            
Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, 
%POS2%, " +
-                                                       "alen, %LEN%);\n" : 
sparseRhs ? "               Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, 
%IN2v%, " +
-                                                       "%POS1%, %IN2i%, 
%POS2%, alen, %LEN%);\n" : "           Vector<T>& %TMP% = vect" + vectName + 
-                                               "Write(%IN1%, %IN2%, 
static_cast<uint32_t>(%POS1%), static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
-                               }
+                                       return sparseLhs ? "            
Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, 
alen, %LEN%, this);\n" : "          Vector<T>& %TMP% = vect" + vectName + 
"Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
+                       }
 
-                               //scalar-scalar operations
-                               case MULT:
-                                       return "                T %TMP% = %IN1% 
* %IN2%;\n";
-                               case DIV:
-                                       return "        T %TMP% = %IN1% / 
%IN2%;\n";
-                               case PLUS:
-                                       return "                T %TMP% = %IN1% 
+ %IN2%;\n";
-                               case MINUS:
-                                       return "        T %TMP% = %IN1% - 
%IN2%;\n";
-                               case MODULUS:
-                                       return "        T %TMP% = 
modulus(%IN1%, %IN2%);\n";
-                               case INTDIV:
-                                       return "        T %TMP% = intDiv(%IN1%, 
%IN2%);\n";
-                               case LESS:
-                                       return "        T %TMP% = (%IN1% < 
%IN2%) ? 1.0 : 0.0;\n";
-                               case LESSEQUAL:
-                                       return "        T %TMP% = (%IN1% <= 
%IN2%) ? 1.0 : 0.0;\n";
-                               case GREATER:
-                                       return "        T %TMP% = (%IN1% > 
(%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
-                               case GREATEREQUAL:
-                                       return "        T %TMP% = (%IN1% >= 
%IN2%) ? 1.0 : 0.0;\n";
-                               case EQUAL:
-                                       return "        T %TMP% = (%IN1% == 
%IN2%) ? 1.0 : 0.0;\n";
-                               case NOTEQUAL:
-                                       return "        T %TMP% = (%IN1% != 
%IN2%) ? 1.0 : 0.0;\n";
+                       //vector-vector operations
+                       case VECT_MULT:
+                       case VECT_DIV:
+                       case VECT_MINUS:
+                       case VECT_PLUS:
+                       case VECT_XOR:
+                       case VECT_BITWAND:
+                       case VECT_BIASADD:
+                       case VECT_BIASMULT:
+                       case VECT_MIN:
+                       case VECT_MAX:
+                       case VECT_EQUAL:
+                       case VECT_NOTEQUAL:
+                       case VECT_LESS:
+                       case VECT_LESSEQUAL:
+                       case VECT_GREATER:
+                       case VECT_GREATEREQUAL: {
+                               String vectName = type.getVectorPrimitiveName();
+                               return sparseLhs ? "            Vector<T>& 
%TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, " +
+                                          "%POS1%, %POS2%, alen, %LEN%, 
this);\n" :
+                                          sparseRhs ? "                
Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, " +
+                                          "%IN2i%, %POS2%, alen, %LEN%);\n" :
+                                          "            Vector<T>& %TMP% = 
vect" + vectName + "Write(%IN1%, %IN2%, " +
+                                          "static_cast<uint32_t>(%POS1%), 
static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
+                       }
 
-                               case MIN:
-                                       return "        T %TMP% = min(%IN1%, 
%IN2%);\n";
-                               case MAX:
-                                       return "        T %TMP% = max(%IN1%, 
%IN2%);\n";
-                               case LOG:
-                                       return "        T %TMP% = 
log(%IN1%)/Math.log(%IN2%);\n";
-                               case LOG_NZ:
-                                       return "        T %TMP% = (%IN1% == 0) 
? 0 : log(%IN1%) / log(%IN2%);\n";
-                               case POW:
-                                       return "        T %TMP% = pow(%IN1%, 
%IN2%);\n";
-                               case MINUS1_MULT:
-                                       return "        T %TMP% = 1 - %IN1% * 
%IN2%;\n";
-                               case MINUS_NZ:
-                                       return "        T %TMP% = (%IN1% != 0) 
? %IN1% - %IN2% : 0;\n";
-                               case XOR:
+                       //scalar-scalar operations
+                       case MULT:
+                               return "                T %TMP% = %IN1% * 
%IN2%;\n";
+                       case DIV:
+                               return "\t\tT %TMP% = %IN1% / %IN2%;\n";
+                       case PLUS:
+                               return "\t\tT %TMP% = %IN1% + %IN2%;\n";
+                       case MINUS:
+                               return "        T %TMP% = %IN1% - %IN2%;\n";
+                       case MODULUS:
+                               return "        T %TMP% = modulus(%IN1%, 
%IN2%);\n";
+                       case INTDIV:
+                               return "        T %TMP% = intDiv(%IN1%, 
%IN2%);\n";
+                       case LESS:
+                               return "        T %TMP% = (%IN1% < %IN2%) ? 1.0 
: 0.0;\n";
+                       case LESSEQUAL:
+                               return "        T %TMP% = (%IN1% <= %IN2%) ? 
1.0 : 0.0;\n";
+                       case GREATER:
+                               return "        T %TMP% = (%IN1% > (%IN2% + 
EPSILON)) ? 1.0 : 0.0;\n";
+                       case GREATEREQUAL:
+                               return "        T %TMP% = (%IN1% >= %IN2%) ? 
1.0 : 0.0;\n";
+                       case EQUAL:
+                               return "        T %TMP% = (%IN1% == %IN2%) ? 
1.0 : 0.0;\n";
+                       case NOTEQUAL:
+                               return "\t\tT %TMP% = (%IN1% != %IN2%) ? 1.0 : 
0.0;\n";
+                       case MIN:
+                               if(isSinglePrecision())
+                                       return "\t\tT %TMP% = fminf(%IN1%, 
%IN2%);\n";
+                               else
+                                       return "\t\tT %TMP% = min(%IN1%, 
%IN2%);\n";
+                       case MAX:
+                               if(isSinglePrecision())
+                                       return "\t\tT %TMP% = fmaxf(%IN1%, 
%IN2%);\n";
+                               else
+                                       return "\t\tT %TMP% = max(%IN1%, 
%IN2%);\n";
+                       case LOG:
+                               if(isSinglePrecision())
+                                       return "\t\tT %TMP% = logf(%IN1%) / 
logf(%IN2%);\n";
+                               else
+                                       return "\t\tT %TMP% = log(%IN1%) / 
log(%IN2%);\n";
+                       case LOG_NZ:
+                               if(isSinglePrecision())
+                                       return "\t\tT %TMP% = (%IN1% == 0) ? 0 
: logf(%IN1%) / logf(%IN2%);\n";
+                               else
+                                       return "\t\tT %TMP% = (%IN1% == 0) ? 0 
: log(%IN1%) / log(%IN2%);\n";
+                       case POW:
+                               if(isSinglePrecision())
+                                       return "\t\tT %TMP% = powf(%IN1%, 
%IN2%);\n";
+                               else
+                                       return "\t\tT %TMP% = pow(%IN1%, 
%IN2%);\n";
+                       case MINUS1_MULT:
+                               return "        T %TMP% = 1 - %IN1% * %IN2%;\n";
+                       case MINUS_NZ:
+                               return "        T %TMP% = (%IN1% != 0) ? %IN1% 
- %IN2% : 0;\n";
+                       case XOR:
 //                                     return "        T %TMP% = ( (%IN1% != 
0.0) != (%IN2% != 0.0) ) ? 1.0 : 0.0;\n";
-                                       return "        T %TMP% = ( (%IN1% < 
EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
-                               case BITWAND:
-                                       return "        T %TMP% = bwAnd(%IN1%, 
%IN2%);\n";
-                               case SEQ_RIX:
-                                       return "                T %TMP% = %IN1% 
+ grix * %IN2%;\n"; //0-based global rix
+                               return "        T %TMP% = ( (%IN1% < EPSILON) 
!= (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
+                       case BITWAND:
+                               return "        T %TMP% = bwAnd(%IN1%, 
%IN2%);\n";
+                       case SEQ_RIX:
+                               return "\t\tT %TMP% = %IN1% + grix * %IN2%;\n"; 
//0-based global rix
 
-                               default:
-                                       throw new RuntimeException("Invalid 
binary type: " + this.toString());
-                       }
+                       default:
+                               throw new RuntimeException("Invalid binary 
type: " + this.toString());
                }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
index dd06d6c004..026fe264f8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
@@ -28,81 +28,41 @@ public class Ternary extends CodeTemplate {
 
        @Override
        public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
-               if(isSinglePrecision()) {
-                       switch (type) {
-                               case PLUS_MULT:
-                                       return "        T %TMP% = %IN1% + %IN2% 
* %IN3%;\n";
+               switch (type) {
+                       case PLUS_MULT:
+                               return "        T %TMP% = %IN1% + %IN2% * 
%IN3%;\n";
 
-                               case MINUS_MULT:
-                                       return "        T %TMP% = %IN1% - %IN2% 
* %IN3%;\n";
+                       case MINUS_MULT:
+                               return "        T %TMP% = %IN1% - %IN2% * 
%IN3%;\n";
 
-                               case BIASADD:
-                                       return "        T %TMP% = %IN1% + 
getValue(%IN2%, cix/%IN3%);\n";
+                       case BIASADD:
+                               return "        T %TMP% = %IN1% + 
getValue(%IN2%, cix/%IN3%);\n";
 
-                               case BIASMULT:
-                                       return "        T %TMP% = %IN1% * 
getValue(%IN2%, cix/%IN3%);\n";
+                       case BIASMULT:
+                               return "        T %TMP% = %IN1% * 
getValue(%IN2%, cix/%IN3%);\n";
 
-                               case REPLACE:
-                                       return "        T %TMP% = (%IN1% == 
%IN2% || (isnan(%IN1%) "
-                                                       + "&& isnan(%IN2%))) ? 
%IN3% : %IN1%;\n";
+                       case REPLACE:
+                               return "        T %TMP% = (%IN1% == %IN2% || 
(isnan(%IN1%) "
+                                               + "&& isnan(%IN2%))) ? %IN3% : 
%IN1%;\n";
 
-                               case REPLACE_NAN:
-                                       return "        T %TMP% = isnan(%IN1%) 
? %IN3% : %IN1%;\n";
+                       case REPLACE_NAN:
+                               return "        T %TMP% = isnan(%IN1%) ? %IN3% 
: %IN1%;\n";
 
-                               case IFELSE:
-                                       return "        T %TMP% = (%IN1% != 0) 
? %IN2% : %IN3%;\n";
+                       case IFELSE:
+                               return "        T %TMP% = (%IN1% != 0) ? %IN2% 
: %IN3%;\n";
 
-                               case LOOKUP_RC1:
-                                       return sparse ?
-                                                       "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
+                       case LOOKUP_RC1:
+                               return sparse ?
+                                               "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
 //                                                     "       T %TMP% = 
getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-                                                       "               T %TMP% 
= %IN1%.val(rix, %IN3%-1);\n";
+                                               "               T %TMP% = 
%IN1%.val(rix, %IN3%-1);\n";
 
-                               case LOOKUP_RVECT1:
-                                       return "\t\tVector<T>& %TMP% = 
getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
 
-                               default:
-                                       throw new RuntimeException("Invalid 
ternary type: " + this.toString());
-                       }
-               }
-               else {
-                       switch (type) {
-                               case PLUS_MULT:
-                                       return "        T %TMP% = %IN1% + %IN2% 
* %IN3%;\n";
-
-                               case MINUS_MULT:
-                                       return "        T %TMP% = %IN1% - %IN2% 
* %IN3%;\n";
-
-                               case BIASADD:
-                                       return "        T %TMP% = %IN1% + 
getValue(%IN2%, cix/%IN3%);\n";
-
-                               case BIASMULT:
-                                       return "        T %TMP% = %IN1% * 
getValue(%IN2%, cix/%IN3%);\n";
-
-                               case REPLACE:
-                                       return "        T %TMP% = (%IN1% == 
%IN2% || (isnan(%IN1%) "
-                                                       + "&& isnan(%IN2%))) ? 
%IN3% : %IN1%;\n";
-
-                               case REPLACE_NAN:
-                                       return "        T %TMP% = isnan(%IN1%) 
? %IN3% : %IN1%;\n";
-
-                               case IFELSE:
-                                       return "        T %TMP% = (%IN1% != 0) 
? %IN2% : %IN3%;\n";
-
-                               case LOOKUP_RC1:
-                                       return sparse ?
-                                                       "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
-//                                                     "       T %TMP% = 
getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-                                                       "               T %TMP% 
= %IN1%.val(rix, %IN3%-1);\n";
-                               
-                               
-                               case LOOKUP_RVECT1:
-                                       return "\t\tVector<T>& %TMP% = 
getVector(%IN1%, %IN2%, rix, %IN3%-1, this);\n";
-
-                               default:
-                                       throw new RuntimeException("Invalid 
ternary type: "+this.toString());
-                       }
+                       case LOOKUP_RVECT1:
+                               return "\t\tVector<T>& %TMP% = getVector(%IN1%, 
%IN2%, rix, %IN3%-1, this);\n";
 
+                       default:
+                               throw new RuntimeException("Invalid ternary 
type: "+this.toString());
                }
        }
 }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
index f2405d5b5c..405b880715 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
@@ -29,210 +29,161 @@ public class Unary extends CodeTemplate {
 
        @Override
        public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-               if(isSinglePrecision()) {
-                       switch( type ) {
-                               case ROW_SUMS:
-                               case ROW_SUMSQS:
-                               case ROW_MINS:
-                               case ROW_MAXS:
-                               case ROW_MEANS:
-                               case ROW_COUNTNNZS: {
-                                       String vectName = 
StringUtils.capitalize(type.name().substring(4, 
type.name().length()-1).toLowerCase());
-                                       return sparse ? "       T %TMP% = 
LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
-                                               "       T %TMP% = 
LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
-                               }
+               switch( type ) {
+                       case ROW_SUMS:
+                       case ROW_SUMSQS:
+                       case ROW_MINS:
+                       case ROW_MAXS:
+                       case ROW_MEANS:
+                       case ROW_COUNTNNZS: {
+                               String vectName = 
StringUtils.capitalize(type.name().substring(4, 
type.name().length()-1).toLowerCase());
+                               return sparse ? "               T %TMP% = 
vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, %LEN%);\n":
+                                       "               T %TMP% = 
vect"+vectName+"(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%);\n";
 
-                               case VECT_EXP:
-                               case VECT_POW2:
-                               case VECT_MULT2:
-                               case VECT_SQRT:
-                               case VECT_LOG:
-                               case VECT_ABS:
-                               case VECT_ROUND:
-                               case VECT_CEIL:
-                               case VECT_FLOOR:
-                               case VECT_SIGN:
-                               case VECT_SIN:
-                               case VECT_COS:
-                               case VECT_TAN:
-                               case VECT_ASIN:
-                               case VECT_ACOS:
-                               case VECT_ATAN:
-                               case VECT_SINH:
-                               case VECT_COSH:
-                               case VECT_TANH:
-                               case VECT_CUMSUM:
-                               case VECT_CUMMIN:
-                               case VECT_CUMMAX:
-                               case VECT_SPROP:
-                               case VECT_SIGMOID: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       return sparse ? "       T[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" 
:
-                                               "       T[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
-                               }
-
-                               case EXP:
-                                       return "        T %TMP% = 
expf(%IN1%);\n";
-                               case LOOKUP_R:
-                                       return sparse ?
-                                               "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
-                                               "       T %TMP% = 
getValue(%IN1%, rix);\n";
-                               case LOOKUP_C:
-                                       return "        T %TMP% = 
getValue(%IN1%, n, 0, cix);\n";
-                               case LOOKUP_RC:
-                                       return "        T %TMP% = 
getValue(%IN1%, n, rix, cix);\n";
-                               case LOOKUP0:
-                                       return "        T %TMP% = %IN1%[0];\n";
-                               case POW2:
-                                       return "        T %TMP% = %IN1% * 
%IN1%;\n";
-                               case MULT2:
-                                       return "        T %TMP% = %IN1% + 
%IN1%;\n";
-                               case ABS:
-                                       return "        T %TMP% = 
fabsf(%IN1%);\n";
-                               case SIN:
-                                       return "        T %TMP% = 
sinf(%IN1%);\n";
-                               case COS:
-                                       return "        T %TMP% = 
cosf(%IN1%);\n";
-                               case TAN:
-                                       return "        T %TMP% = 
tanf(%IN1%);\n";
-                               case ASIN:
-                                       return "        T %TMP% = 
asinf(%IN1%);\n";
-                               case ACOS:
-                                       return "        T %TMP% = 
acosf(%IN1%);\n";
-                               case ATAN:
-                                       return "        T %TMP% = 
atanf(%IN1%);\n";
-                               case SINH:
-                                       return "        T %TMP% = 
sinhf(%IN1%);\n";
-                               case COSH:
-                                       return "        T %TMP% = 
coshf(%IN1%);\n";
-                               case TANH:
-                                       return "        T %TMP% = 
tanhf(%IN1%);\n";
-                               case SIGN:
-                                       return "        T %TMP% = 
signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
-                               case SQRT:
-                                       return "        T %TMP% = 
sqrtf(%IN1%);\n";
-                               case LOG:
-                                       return "        T %TMP% = 
logf(%IN1%);\n";
-                               case ROUND:
-                                       return "        T %TMP% = 
roundf(%IN1%);\n";
-                               case CEIL:
-                                       return "        T %TMP% = 
ceilf(%IN1%);\n";
-                               case FLOOR:
-                                       return "        T %TMP% = 
floorf(%IN1%);\n";
-                               case SPROP:
-                                       return "        T %TMP% = %IN1% * (1 - 
%IN1%);\n";
-                               case SIGMOID:
-                                       return "        T %TMP% = 1 / (1 + 
expf(-%IN1%));\n";
-                               case LOG_NZ:
-                                       return "        T %TMP% = (%IN1%==0) ? 
0 : logf(%IN1%);\n";
-
-                               default:
-                                       throw new RuntimeException("Invalid 
unary type: "+this.toString());
                        }
-               }
-               else { /* double precision */
-                       switch( type ) {
-                               case ROW_SUMS:
-                               case ROW_SUMSQS:
-                               case ROW_MINS:
-                               case ROW_MAXS:
-                               case ROW_MEANS:
-                               case ROW_COUNTNNZS: {
-                                       String vectName = 
StringUtils.capitalize(type.name().substring(4, 
type.name().length()-1).toLowerCase());
-                                       return sparse ? "               T %TMP% 
= vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, %LEN%);\n":
-                                               "               T %TMP% = 
vect"+vectName+"(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%);\n";
-                                       
-                               }
 
-                               case VECT_EXP:
-                               case VECT_POW2:
-                               case VECT_MULT2:
-                               case VECT_SQRT:
-                               case VECT_LOG:
-                               case VECT_ABS:
-                               case VECT_ROUND:
-                               case VECT_CEIL:
-                               case VECT_FLOOR:
-                               case VECT_SIGN:
-                               case VECT_SIN:
-                               case VECT_COS:
-                               case VECT_TAN:
-                               case VECT_ASIN:
-                               case VECT_ACOS:
-                               case VECT_ATAN:
-                               case VECT_SINH:
-                               case VECT_COSH:
-                               case VECT_TANH:
-                               case VECT_CUMSUM:
-                               case VECT_CUMMIN:
-                               case VECT_CUMMAX:
-                               case VECT_SPROP:
-                               case VECT_SIGMOID: {
-                                       String vectName = 
type.getVectorPrimitiveName();
-                                       return sparse ? "               
Vector<T>& %TMP% = vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, %LEN%, 
this);\n" :
-                                               "               Vector<T>& 
%TMP% = vect"+vectName+"Write(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%, 
this);\n";
-                               }
+                       case VECT_EXP:
+                       case VECT_POW2:
+                       case VECT_MULT2:
+                       case VECT_SQRT:
+                       case VECT_LOG:
+                       case VECT_ABS:
+                       case VECT_ROUND:
+                       case VECT_CEIL:
+                       case VECT_FLOOR:
+                       case VECT_SIGN:
+                       case VECT_SIN:
+                       case VECT_COS:
+                       case VECT_TAN:
+                       case VECT_ASIN:
+                       case VECT_ACOS:
+                       case VECT_ATAN:
+                       case VECT_SINH:
+                       case VECT_COSH:
+                       case VECT_TANH:
+                       case VECT_CUMSUM:
+                       case VECT_CUMMIN:
+                       case VECT_CUMMAX:
+                       case VECT_SPROP:
+                       case VECT_SIGMOID: {
+                               String vectName = type.getVectorPrimitiveName();
+                               return sparse ? "               Vector<T>& 
%TMP% = vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, %LEN%, this);\n" :
+                                       "               Vector<T>& %TMP% = 
vect"+vectName+"Write(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
+                       }
 
-                               case EXP:
+                       case EXP:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
expf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
exp(%IN1%);\n";
-                               case LOOKUP_R:
-                                       return sparse ?
-                                               "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
-                                               "               T %TMP% = 
%IN1%.val(rix);\n";
-//                                             "       T %TMP% = 
getValue(%IN1%, rix);\n";
-                               case LOOKUP_C:
-                                       return "        T %TMP% = 
getValue(%IN1%, n, 0, cix);\n";
-                               case LOOKUP_RC:
-                                       return "        T %TMP% = 
getValue(%IN1%, n, rix, cix);\n";
-                               case LOOKUP0:
-                                       return "        T %TMP% = %IN1%[0];\n";
-                               case POW2:
-                                       return "        T %TMP% = %IN1% * 
%IN1%;\n";
-                               case MULT2:
-                                       return "        T %TMP% = %IN1% + 
%IN1%;\n";
-                               case ABS:
+                       case LOOKUP_R:
+                               return sparse ?
+                                       "\t\tT %TMP% = getValue(%IN1v%, %IN1i%, 
ai, alen, 0);\n" :
+//                                             "               T %TMP% = 
%IN1%.val(rix);\n";
+                                       "\t\tT %TMP% = getValue(%IN1%, rix);\n";
+                       case LOOKUP_C:
+                               return "\t\tT %TMP% = getValue(%IN1%, n, 0, 
cix);\n";
+                       case LOOKUP_RC:
+                               return "\t\tT %TMP% = getValue(%IN1%, n, rix, 
cix);\n";
+                       case LOOKUP0:
+                               return "\t\tT %TMP% = %IN1%[0];\n";
+                       case POW2:
+                               return "        T %TMP% = %IN1% * %IN1%;\n";
+                       case MULT2:
+                               return "        T %TMP% = %IN1% + %IN1%;\n";
+                       case ABS:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
fabsf(%IN1%);\n";
+                               else
                                        return "\t\tT %TMP% = fabs(%IN1%);\n";
-                               case SIN:
+                       case SIN:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
sinf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
sin(%IN1%);\n";
-                               case COS:
+                       case COS:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
cosf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
cos(%IN1%);\n";
-                               case TAN:
+                       case TAN:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
tanf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
tan(%IN1%);\n";
-                               case ASIN:
+                       case ASIN:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
asinf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
asin(%IN1%);\n";
-                               case ACOS:
+                       case ACOS:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
acosf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
acos(%IN1%);\n";
-                               case ATAN:
+                       case ATAN:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
atanf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
atan(%IN1%);\n";
-                               case SINH:
+                       case SINH:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
sinhf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
sinh(%IN1%);\n";
-                               case COSH:
+                       case COSH:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
coshf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
cosh(%IN1%);\n";
-                               case TANH:
+                       case TANH:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
tanhf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
tanh(%IN1%);\n";
-                               case SIGN:
-                                       return "        T %TMP% = 
signbit(%IN1%) == 0 ? 1.0 : -1.0;\n";
-                               case SQRT:
+                       case SIGN:
+                               return "        T %TMP% = signbit(%IN1%) == 0 ? 
1.0 : -1.0;\n";
+                       case SQRT:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
sqrtf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
sqrt(%IN1%);\n";
-                               case LOG:
+                       case LOG:
+
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
logf(%IN1%);\n";
+                               else
                                        return "                T %TMP% = 
log(%IN1%);\n";
-                               case ROUND:
+                       case ROUND:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
roundf(%IN1%);\n";
+                               else
                                        return "\t\tT %TMP% = round(%IN1%);\n";
-                               case CEIL:
+                       case CEIL:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
ceilf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
ceil(%IN1%);\n";
-                               case FLOOR:
+                       case FLOOR:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 
floorf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = 
floor(%IN1%);\n";
-                               case SPROP:
-                                       return "        T %TMP% = %IN1% * (1 - 
%IN1%);\n";
-                               case SIGMOID:
+                       case SPROP:
+                               return "        T %TMP% = %IN1% * (1 - 
%IN1%);\n";
+                       case SIGMOID:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = 1 / (1 + 
expf(-%IN1%));\n";
+                               else
                                        return "        T %TMP% = 1 / (1 + 
exp(-%IN1%));\n";
-                               case LOG_NZ:
+                       case LOG_NZ:
+                               if(isSinglePrecision())
+                                       return "        T %TMP% = (%IN1%==0) ? 
0 : logf(%IN1%);\n";
+                               else
                                        return "        T %TMP% = (%IN1%==0) ? 
0 : log(%IN1%);\n";
 
-                               default:
-                                       throw new RuntimeException("Invalid 
unary type: "+this.toString());
-                       }
-
+                       default:
+                               throw new RuntimeException("Invalid unary type: 
"+this.toString());
                }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
index 03c35da540..cfe5780326 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
@@ -116,9 +116,9 @@ public class SpoofCUDACellwise extends SpoofCellwise 
implements SpoofCUDAOperato
        }
 
        public int execute_dp(long ctx) { return execute_d(ctx); }
-       public int execute_sp(long ctx) { return execute_d(ctx); }
+       public int execute_sp(long ctx) { return execute_f(ctx); }
        public long getContext() { return ctx; }
 
        public static native int execute_d(long ctx);
-       public static native int execute_s(long ctx);
+       public static native int execute_f(long ctx);
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
index 47826a9461..0adf2ec605 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
@@ -96,9 +96,9 @@ public class SpoofCUDARowwise extends SpoofRowwise implements 
SpoofCUDAOperator
                int ci, int alen, int n, long grix, int rix) { }
 
        public int execute_dp(long ctx) { return execute_d(ctx); }
-       public int execute_sp(long ctx) { return execute_d(ctx); }
+       public int execute_sp(long ctx) { return execute_f(ctx); }
        public long getContext() { return ctx; }
 
        public static native int execute_d(long ctx);
-       public static native int execute_s(long ctx);
+       public static native int execute_f(long ctx);
 }

Reply via email to