Repository: systemml
Updated Branches:
  refs/heads/master b7f569bd0 -> 5ca8706e9


[SYSTEMML-445] Added rshape operator for the GPU backend

- This leads to a 1.2x speedup for ResNet200 with a batch size of 32 by reducing the number of host-to-device transfers.
- Also, added GPU tests for this operator.
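
For context, a minimal sketch of the kind of script that now stays on the device
(shapes are illustrative, not taken from the ResNet200 benchmark; it assumes the
standard MLContext entry point with the GPU backend enabled):

import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.Script;
import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;

// Sketch only: a byrow=FALSE reshape that previously fell back to the CPU
// (forcing extra host/device transfers) now runs as the gpu_rshape instruction.
MLContext ml = new MLContext(spark);   // 'spark' is an existing SparkSession
ml.setGPU(true);
Script script = dml(
      "x = rand(rows=10, cols=3, seed=42);"
    + "y = matrix(x, rows=15, cols=2, byrow=FALSE);"
    + "print(sum(y));");
ml.execute(script);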


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5ca8706e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5ca8706e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5ca8706e

Branch: refs/heads/master
Commit: 5ca8706e98ba8de7418a24405d9d3bb600dfe468
Parents: b7f569b
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Thu Aug 9 09:45:42 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Thu Aug 9 09:46:02 2018 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  32 ++++
 src/main/cpp/kernels/SystemML.ptx               | 154 +++++++++++++++----
 .../java/org/apache/sysml/hops/ReorgOp.java     |   4 +-
 .../instructions/GPUInstructionParser.java      |   5 +
 .../instructions/gpu/GPUInstruction.java        |   1 +
 .../gpu/MatrixReshapeGPUInstruction.java        | 104 +++++++++++++
 .../runtime/matrix/data/LibMatrixCUDA.java      |   4 +-
 .../org/apache/sysml/test/gpu/ReshapeTest.java  |  95 ++++++++++++
 8 files changed, 369 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 082daac..485b7e2 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2247,3 +2247,35 @@ extern "C" __global__ void prepare_lstm_dinput_f(float* smlInput, float* cudnnIn
   prepare_lstm_dinput(smlInput, cudnnInput, N, D, TD, size);
 }
 
+
+/**
+ * Performs a column-wise (byrow=FALSE) reshape of the input matrix
+ * @param A the input matrix (stored row-major, of length = size)
+ * @param C the pre-allocated output matrix (of length = size)
+ * @param size the number of elements in the input and output matrices
+ * @param inRows number of rows of the input matrix
+ * @param inCols number of columns of the input matrix
+ * @param outRows number of rows of the output matrix
+ * @param outCols number of columns of the output matrix
+ */
+template <typename T>
+__device__ void colwise_reshape(T *A, T *C, unsigned int size, 
+       unsigned int inRows, unsigned int inCols,
+       unsigned int outRows, unsigned int outCols) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    int i = index / outCols;
+    int j = index % outCols;
+    int k = (outRows*j+i) % inRows;
+    int l = (outRows*j+i) / inRows;
+    C[index] = A[k*inCols+l];
+  }
+}
+
+extern "C" __global__ void colwise_reshape_d(double *A, double *C, unsigned int size, 
+       unsigned int inRows, unsigned int inCols,
+       unsigned int outRows, unsigned int outCols) {
+  colwise_reshape(A, C, size, inRows, inCols, outRows, outCols);
+}
+
+extern "C" __global__ void colwise_reshape_f(float *A, float *C, unsigned int size, 
+       unsigned int inRows, unsigned int inCols,
+       unsigned int outRows, unsigned int outCols) {
+  colwise_reshape(A, C, size, inRows, inCols, outRows, outCols);
+}
\ No newline at end of file
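
To make the index arithmetic in colwise_reshape easier to follow, here is a small
CPU reference of the same mapping (a sketch in plain Java; the method name is
invented for illustration). Both buffers are assumed dense and row-major, exactly
as the kernel assumes:

// CPU mirror of colwise_reshape: output element (i, j) corresponds to
// column-major position p = outRows*j + i, which is found at input row
// k = p % inRows and input column l = p / inRows.
static void colwiseReshapeReference(double[] A, double[] C,
    int inRows, int inCols, int outRows, int outCols) {
  int size = outRows * outCols;
  for (int index = 0; index < size; index++) {
    int i = index / outCols;     // output row (row-major buffer)
    int j = index % outCols;     // output column
    int p = outRows * j + i;     // column-major position of (i, j)
    int k = p % inRows;          // input row holding that position
    int l = p / inRows;          // input column
    C[index] = A[k * inCols + l];
  }
}

A byrow=TRUE reshape needs no such permutation, because a row-wise reshape of a
row-major buffer preserves the linear element order; that case is handled by a
plain device copy in the new GPU instruction below.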

http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 9d8b178..93e5e35 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -12877,12 +12877,112 @@ BB111_2:
        ret;
 }
 
+       // .globl       colwise_reshape_d
+.visible .entry colwise_reshape_d(
+       .param .u64 colwise_reshape_d_param_0,
+       .param .u64 colwise_reshape_d_param_1,
+       .param .u32 colwise_reshape_d_param_2,
+       .param .u32 colwise_reshape_d_param_3,
+       .param .u32 colwise_reshape_d_param_4,
+       .param .u32 colwise_reshape_d_param_5,
+       .param .u32 colwise_reshape_d_param_6
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<16>;
+       .reg .f64       %fd<2>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [colwise_reshape_d_param_0];
+       ld.param.u64    %rd2, [colwise_reshape_d_param_1];
+       ld.param.u32    %r6, [colwise_reshape_d_param_2];
+       ld.param.u32    %r2, [colwise_reshape_d_param_3];
+       ld.param.u32    %r3, [colwise_reshape_d_param_4];
+       ld.param.u32    %r4, [colwise_reshape_d_param_5];
+       ld.param.u32    %r5, [colwise_reshape_d_param_6];
+       mov.u32         %r7, %ctaid.x;
+       mov.u32         %r8, %ntid.x;
+       mov.u32         %r9, %tid.x;
+       mad.lo.s32      %r1, %r8, %r7, %r9;
+       setp.ge.u32     %p1, %r1, %r6;
+       @%p1 bra        BB112_2;
+
+       cvta.to.global.u64      %rd3, %rd1;
+       rem.u32         %r10, %r1, %r5;
+       div.u32         %r11, %r1, %r5;
+       mad.lo.s32      %r12, %r10, %r4, %r11;
+       rem.u32         %r13, %r12, %r2;
+       div.u32         %r14, %r12, %r2;
+       mad.lo.s32      %r15, %r13, %r3, %r14;
+       mul.wide.u32    %rd4, %r15, 8;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f64   %fd1, [%rd5];
+       cvta.to.global.u64      %rd6, %rd2;
+       mul.wide.s32    %rd7, %r1, 8;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f64   [%rd8], %fd1;
+
+BB112_2:
+       ret;
+}
+
+       // .globl       colwise_reshape_f
+.visible .entry colwise_reshape_f(
+       .param .u64 colwise_reshape_f_param_0,
+       .param .u64 colwise_reshape_f_param_1,
+       .param .u32 colwise_reshape_f_param_2,
+       .param .u32 colwise_reshape_f_param_3,
+       .param .u32 colwise_reshape_f_param_4,
+       .param .u32 colwise_reshape_f_param_5,
+       .param .u32 colwise_reshape_f_param_6
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<2>;
+       .reg .b32       %r<16>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [colwise_reshape_f_param_0];
+       ld.param.u64    %rd2, [colwise_reshape_f_param_1];
+       ld.param.u32    %r6, [colwise_reshape_f_param_2];
+       ld.param.u32    %r2, [colwise_reshape_f_param_3];
+       ld.param.u32    %r3, [colwise_reshape_f_param_4];
+       ld.param.u32    %r4, [colwise_reshape_f_param_5];
+       ld.param.u32    %r5, [colwise_reshape_f_param_6];
+       mov.u32         %r7, %ctaid.x;
+       mov.u32         %r8, %ntid.x;
+       mov.u32         %r9, %tid.x;
+       mad.lo.s32      %r1, %r8, %r7, %r9;
+       setp.ge.u32     %p1, %r1, %r6;
+       @%p1 bra        BB113_2;
+
+       cvta.to.global.u64      %rd3, %rd1;
+       rem.u32         %r10, %r1, %r5;
+       div.u32         %r11, %r1, %r5;
+       mad.lo.s32      %r12, %r10, %r4, %r11;
+       rem.u32         %r13, %r12, %r2;
+       div.u32         %r14, %r12, %r2;
+       mad.lo.s32      %r15, %r13, %r3, %r14;
+       mul.wide.u32    %rd4, %r15, 4;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f32   %f1, [%rd5];
+       cvta.to.global.u64      %rd6, %rd2;
+       mul.wide.s32    %rd7, %r1, 4;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f32   [%rd8], %f1;
+
+BB113_2:
+       ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
        .param .b64 __internal_trig_reduction_slowpathd_param_0,
        .param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-       .local .align 8 .b8     __local_depot112[40];
+       .local .align 8 .b8     __local_depot114[40];
        .reg .b64       %SP;
        .reg .b64       %SPL;
        .reg .pred      %p<9>;
@@ -12891,7 +12991,7 @@ BB111_2:
        .reg .b64       %rd<102>;
 
 
-       mov.u64         %rd101, __local_depot112;
+       mov.u64         %rd101, __local_depot114;
        cvta.local.u64  %SP, %rd101;
        ld.param.f64    %fd4, [__internal_trig_reduction_slowpathd_param_0];
        ld.param.u64    %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -12905,7 +13005,7 @@ BB111_2:
        shr.u32         %r3, %r1, 20;
        bfe.u32         %r4, %r1, 20, 11;
        setp.eq.s32     %p1, %r4, 2047;
-       @%p1 bra        BB112_13;
+       @%p1 bra        BB114_13;
 
        add.s32         %r15, %r4, -1024;
        shr.u32         %r16, %r15, 6;
@@ -12918,7 +13018,7 @@ BB111_2:
        mov.u64         %rd94, 0;
        setp.ge.s32     %p2, %r5, %r6;
        mov.u64         %rd93, %rd1;
-       @%p2 bra        BB112_4;
+       @%p2 bra        BB114_4;
 
        mov.b64          %rd41, %fd4;
        shl.b64         %rd42, %rd41, 11;
@@ -12935,7 +13035,7 @@ BB111_2:
        mov.u64         %rd91, %rd1;
        mov.u32         %r39, %r5;
 
-BB112_3:
+BB114_3:
        .pragma "nounroll";
        ld.const.u64    %rd47, [%rd89];
        // inline asm
@@ -12965,15 +13065,15 @@ BB112_3:
        add.s64         %rd93, %rd93, 8;
        add.s64         %rd89, %rd89, 8;
        setp.lt.s32     %p3, %r39, %r6;
-       @%p3 bra        BB112_3;
+       @%p3 bra        BB114_3;
 
-BB112_4:
+BB114_4:
        st.local.u64    [%rd93], %rd94;
        ld.local.u64    %rd95, [%rd1+16];
        ld.local.u64    %rd96, [%rd1+24];
        and.b32         %r9, %r3, 63;
        setp.eq.s32     %p4, %r9, 0;
-       @%p4 bra        BB112_6;
+       @%p4 bra        BB114_6;
 
        mov.u32         %r27, 64;
        sub.s32         %r28, %r27, %r9;
@@ -12985,7 +13085,7 @@ BB112_4:
        shr.u64         %rd55, %rd54, %r28;
        or.b64          %rd95, %rd55, %rd53;
 
-BB112_6:
+BB114_6:
        cvta.to.local.u64       %rd56, %rd37;
        shr.u64         %rd57, %rd96, 62;
        cvt.u32.u64     %r29, %rd57;
@@ -13002,7 +13102,7 @@ BB112_6:
        selp.b32        %r34, %r32, %r33, %p5;
        st.local.u32    [%rd56], %r34;
        setp.eq.s32     %p6, %r31, 0;
-       @%p6 bra        BB112_8;
+       @%p6 bra        BB114_8;
 
        mov.u64         %rd64, 0;
        // inline asm
@@ -13022,10 +13122,10 @@ BB112_6:
        // inline asm
        xor.b32         %r40, %r40, -2147483648;
 
-BB112_8:
+BB114_8:
        clz.b64         %r41, %rd98;
        setp.eq.s32     %p7, %r41, 0;
-       @%p7 bra        BB112_10;
+       @%p7 bra        BB114_10;
 
        shl.b64         %rd67, %rd98, %r41;
        mov.u32         %r35, 64;
@@ -13033,7 +13133,7 @@ BB112_8:
        shr.u64         %rd68, %rd97, %r36;
        or.b64          %rd98, %rd68, %rd67;
 
-BB112_10:
+BB114_10:
        mov.u64         %rd72, -3958705157555305931;
        // inline asm
        {
@@ -13054,7 +13154,7 @@ BB112_10:
        }
        // inline asm
        setp.lt.s64     %p8, %rd100, 1;
-       @%p8 bra        BB112_12;
+       @%p8 bra        BB114_12;
 
        // inline asm
        {
@@ -13073,7 +13173,7 @@ BB112_10:
        // inline asm
        add.s32         %r41, %r41, 1;
 
-BB112_12:
+BB114_12:
        cvt.u64.u32     %rd79, %r40;
        shl.b64         %rd80, %rd79, 32;
        mov.u32         %r37, 1022;
@@ -13088,7 +13188,7 @@ BB112_12:
        or.b64          %rd88, %rd87, %rd80;
        mov.b64          %fd4, %rd88;
 
-BB112_13:
+BB114_13:
        st.param.f64    [func_retval0+0], %fd4;
        ret;
 }
@@ -13116,7 +13216,7 @@ BB112_13:
        }
        shr.u32         %r51, %r50, 20;
        setp.ne.s32     %p1, %r51, 0;
-       @%p1 bra        BB113_2;
+       @%p1 bra        BB115_2;
 
        mul.f64         %fd14, %fd12, 0d4350000000000000;
        {
@@ -13130,13 +13230,13 @@ BB112_13:
        shr.u32         %r16, %r50, 20;
        add.s32         %r51, %r16, -54;
 
-BB113_2:
+BB115_2:
        add.s32         %r52, %r51, -1023;
        and.b32         %r17, %r50, -2146435073;
        or.b32          %r18, %r17, 1072693248;
        mov.b64         %fd135, {%r49, %r18};
        setp.lt.u32     %p2, %r18, 1073127583;
-       @%p2 bra        BB113_4;
+       @%p2 bra        BB115_4;
 
        {
        .reg .b32 %temp; 
@@ -13150,7 +13250,7 @@ BB113_2:
        mov.b64         %fd135, {%r19, %r21};
        add.s32         %r52, %r51, -1022;
 
-BB113_4:
+BB115_4:
        add.f64         %fd15, %fd135, 0d3FF0000000000000;
        rcp.approx.ftz.f64      %fd16, %fd15;
        neg.f64         %fd17, %fd15;
@@ -13313,13 +13413,13 @@ BB113_4:
        mov.b32          %f2, %r35;
        abs.f32         %f1, %f2;
        setp.lt.f32     %p4, %f1, 0f4086232B;
-       @%p4 bra        BB113_7;
+       @%p4 bra        BB115_7;
 
        setp.lt.f64     %p5, %fd4, 0d0000000000000000;
        add.f64         %fd129, %fd4, 0d7FF0000000000000;
        selp.f64        %fd136, 0d0000000000000000, %fd129, %p5;
        setp.geu.f32    %p6, %f1, 0f40874800;
-       @%p6 bra        BB113_7;
+       @%p6 bra        BB115_7;
 
        mov.f64         %fd134, 0d4338000000000000;
        mov.f64         %fd133, 0d3FF71547652B82FE;
@@ -13341,26 +13441,26 @@ BB113_4:
        mov.b64         %fd131, {%r44, %r43};
        mul.f64         %fd136, %fd130, %fd131;
 
-BB113_7:
+BB115_7:
        {
        .reg .b32 %temp; 
        mov.b64         {%temp, %r45}, %fd136;
        }
        and.b32         %r46, %r45, 2147483647;
        setp.ne.s32     %p7, %r46, 2146435072;
-       @%p7 bra        BB113_9;
+       @%p7 bra        BB115_9;
 
        {
        .reg .b32 %temp; 
        mov.b64         {%r47, %temp}, %fd136;
        }
        setp.eq.s32     %p8, %r47, 0;
-       @%p8 bra        BB113_10;
+       @%p8 bra        BB115_10;
 
-BB113_9:
+BB115_9:
        fma.rn.f64      %fd136, %fd136, %fd5, %fd136;
 
-BB113_10:
+BB115_10:
        st.param.f64    [func_retval0+0], %fd136;
        ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/main/java/org/apache/sysml/hops/ReorgOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index eb5d825..4fa782d 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -132,9 +132,11 @@ public class ReorgOp extends MultiThreadedHop
                                        return true;
                                }
                        }
+                       case RESHAPE: {
+                               return true;
+                       }
                        case DIAG:
                        case REV:
-                       case RESHAPE:
                        case SORT:
                                return false;
                        default:

http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 01b10a8..c90f9f9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -30,6 +30,7 @@ import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.MatrixIndexingGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.MatrixMatrixAxpyGPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.MatrixReshapeGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
 import org.apache.sysml.runtime.instructions.gpu.MMTSJGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.RelationalBinaryGPUInstruction;
@@ -69,6 +70,7 @@ public class GPUInstructionParser  extends InstructionParser
 
                // Reorg/Transpose
                String2GPUInstructionType.put( "r'",    GPUINSTRUCTION_TYPE.Reorg);
+               String2GPUInstructionType.put( "rshape",GPUINSTRUCTION_TYPE.MatrixReshape);
 
                // Matrix Manipulation
                String2GPUInstructionType.put( "append", GPUINSTRUCTION_TYPE.Append);
@@ -193,6 +195,9 @@ public class GPUInstructionParser  extends InstructionParser
                        case Reorg:
                                return ReorgGPUInstruction.parseInstruction(str);
                                
+                       case MatrixReshape:
+                               return MatrixReshapeGPUInstruction.parseInstruction(str);
+                               
                        case ArithmeticBinary:
                                String opcode = InstructionUtils.getOpCode(str);
                                if( opcode.equals("+*") || opcode.equals("-*") )

http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
index f865f9b..e3c444a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
@@ -41,6 +41,7 @@ public abstract class GPUInstruction extends Instruction {
                Dnn,
                MMTSJ,
                Reorg,
+               MatrixReshape,
                Append,
                ArithmeticBinary,
                BuiltinUnary,

http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
new file mode 100644
index 0000000..61cb643
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.instructions.gpu;
+
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.functionobjects.SwapIndex;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.BooleanObject;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+
+public class MatrixReshapeGPUInstruction extends GPUInstruction {
+       
+       private final CPOperand _input;
+       private final CPOperand _output;
+       private final CPOperand _opRows;
+       private final CPOperand _opCols;
+       private final CPOperand _opByRow;
+       
+       protected MatrixReshapeGPUInstruction(Operator op, String opcode, String istr, 
+                       CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out) {
+               super(op, opcode, istr);
+               _input = in1;
+               _opRows = in2;
+               _opCols = in3;
+               _opByRow = in4;
+               _output = out;
+       }
+       
+       public static MatrixReshapeGPUInstruction parseInstruction ( String str ) {
+               String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+               InstructionUtils.checkNumFields( parts, 5 );
+               String opcode = parts[0];
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand in3 = new CPOperand(parts[3]);
+               CPOperand in4 = new CPOperand(parts[4]);
+               CPOperand out = new CPOperand(parts[5]);
+               if(!opcode.equalsIgnoreCase("rshape"))
+                       throw new DMLRuntimeException("Unknown opcode while parsing a MatrixReshapeGPUInstruction: " + str);
+               else
+                       return new MatrixReshapeGPUInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, str, in1, in2, in3, in4, out);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               int rows = (int)ec.getScalarInput(_opRows.getName(), _opRows.getValueType(), _opRows.isLiteral()).getLongValue(); //safe cast
+               int cols = (int)ec.getScalarInput(_opCols.getName(), _opCols.getValueType(), _opCols.isLiteral()).getLongValue(); //safe cast
+               BooleanObject byRow = (BooleanObject) ec.getScalarInput(_opByRow.getName(), ValueType.BOOLEAN, _opByRow.isLiteral());
+               
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+               String instName = getExtendedOpcode();
+               GPUContext gCtx = ec.getGPUContext(0); 
+               MatrixObject mat = getMatrixInputForGPUInstruction(ec, _input.getName());
+               if(rows*cols != mat.getNumRows()*mat.getNumColumns()) {
+                       throw new DMLRuntimeException("Incorrect number of rows and cols in rshape instruction");
+               }
+               // We currently support only dense rshape
+               Pointer inPtr = LibMatrixCUDA.getDensePointer(gCtx, mat, instName);
+               MatrixObject out = LibMatrixCUDA.getDenseMatrixOutputForGPUInstruction(ec, instName, _output.getName(), rows, cols);
+               Pointer outPtr = LibMatrixCUDA.getDensePointer(gCtx, out, instName);
+               if(byRow.getBooleanValue()) {
+                       // byrow = TRUE is a simple memcpy and metadata update
+                       LibMatrixCUDA.deviceCopy(instName, inPtr, outPtr, LibMatrixCUDA.toInt(mat.getNumRows()), LibMatrixCUDA.toInt(mat.getNumColumns()));
+               }
+               else {
+                       // byrow = FALSE uses a custom kernel to perform rshape
+                       LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("colwise_reshape", 
+                               ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
+                               inPtr, outPtr, LibMatrixCUDA.toInt(rows*cols), 
+                               LibMatrixCUDA.toInt(mat.getNumRows()), LibMatrixCUDA.toInt(mat.getNumColumns()),
+                               rows, cols);
+               }
+               ec.releaseMatrixInputForGPUInstruction(_input.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+       }
+
+}
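
As a concrete illustration of the two branches in processInstruction (a made-up
2x3 example, not taken from the commit):

// Worked example (hypothetical values). Input X is 2x3, stored row-major:
//   X = [ 1 2 3 ]        linear buffer: [1, 2, 3, 4, 5, 6]
//       [ 4 5 6 ]
//
// matrix(X, rows=3, cols=2, byrow=TRUE)  -> deviceCopy; only metadata changes:
//   [ 1 2 ]
//   [ 3 4 ]
//   [ 5 6 ]
//
// matrix(X, rows=3, cols=2, byrow=FALSE) -> colwise_reshape kernel permutes:
//   [ 1 5 ]
//   [ 4 3 ]
//   [ 2 6 ]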

http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 464c4c2..217acd6 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -1542,7 +1542,7 @@ public class LibMatrixCUDA {
         * @param rlen number of rows
         * @param clen number of columns
         */
-       private static void deviceCopy(String instName, Pointer src, Pointer dest, int rlen, int clen) {
+       public static void deviceCopy(String instName, Pointer src, Pointer dest, int rlen, int clen) {
                long t0=0;
                if (DMLScript.FINEGRAINED_STATISTICS) t0 = System.nanoTime();
                int size = rlen * clen * sizeOfDataType;
@@ -2512,7 +2512,7 @@ public class LibMatrixCUDA {
         * @param numCols number of columns of output matrix object
         * @return      the matrix object
         */
-       protected static MatrixObject getDenseMatrixOutputForGPUInstruction(ExecutionContext ec, String instName, String name, long numRows, long numCols) {
+       public static MatrixObject getDenseMatrixOutputForGPUInstruction(ExecutionContext ec, String instName, String name, long numRows, long numCols) {
                long t0=0;
                if (DMLScript.FINEGRAINED_STATISTICS) t0 = System.nanoTime();
                Pair<MatrixObject, Boolean> mb = ec.getDenseMatrixOutputForGPUInstruction(name, numRows, numCols);

http://git-wip-us.apache.org/repos/asf/systemml/blob/5ca8706e/src/test/java/org/apache/sysml/test/gpu/ReshapeTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/ReshapeTest.java b/src/test/java/org/apache/sysml/test/gpu/ReshapeTest.java
new file mode 100644
index 0000000..199a722
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/ReshapeTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the GPU reshape (rshape) operator
+ */
+public class ReshapeTest extends GPUTests {
+
+       private final static String TEST_NAME = "ReshapeTests";
+       private final int seed = 42;
+
+       @Override
+       public void setUp() {
+               super.setUp();
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_DIR, TEST_NAME);
+               getAndLoadTestConfiguration(TEST_NAME);
+       }
+
+       @Test
+       public void testDenseReshape1() {
+               testReshape(1, 10, 10, 1, true, 0.9);
+       }
+       
+       @Test
+       public void testDenseReshape2() {
+               testReshape(1, 10, 10, 1, false, 0.9);
+       }
+       
+       @Test
+       public void testDenseReshape5() {
+               testReshape(10, 3, 3, 10, true, 0.9);
+       }
+       
+       @Test
+       public void testDenseReshape6() {
+               testReshape(10, 3, 3, 10, false, 0.9);
+       }
+       
+       @Test
+       public void testDenseReshape3() {
+               testReshape(10, 3, 15, 2, true, 0.9);
+       }
+       
+       @Test
+       public void testDenseReshape4() {
+               testReshape(10, 3, 15, 2, false, 0.9);
+       }
+       
+       @Test
+       public void testSparseReshape7() {
+               testReshape(10, 3, 15, 2, true, 0.1);
+       }
+       
+       @Test
+       public void testSparseReshape8() {
+               testReshape(10, 3, 15, 2, false, 0.1);
+       }
+       
+       private void testReshape(int inRows, int inCols, int outRows, int outCols, boolean byrow, double sparsity) {
+               System.out.println("Starting testReshape:" + inRows + " " + inCols + " " + outRows + " " + outCols + " " + byrow + " " + sparsity);
+               String scriptStr = "output = matrix(x, rows=" + outRows + ", cols=" + outCols + ", byrow=" + (byrow ? "TRUE" : "FALSE") + ");";
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, sparsity, seed));
+               List<String> outputs = Arrays.asList("output");
+               List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+               List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+               assertHeavyHitterPresent("gpu_rshape");
+               assertEqualObjects(outCPU.get(0), outGPU.get(0));
+       }
+}
