Repository: incubator-systemml Updated Branches: refs/heads/master 6e7e8873a -> d127dfa2d
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d127dfa2/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java index 6b5c14b..2547a87 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java @@ -58,7 +58,7 @@ public class LibMatrixDNN { MaxPooling_Forward, MaxPooling_Backward, // Alternate approaches that we tried but the performance was unsatisfactory be included: direct, non-looped im2col LoopedIm2ColConv2d, LoopedIm2ColConv2dBwdFilter, LoopedIm2ColConv2dBwdData, - BiasAdd, ReluBackward + BiasAdd, ReluBackward, BiasMultiply } // ------------------------------------------------------------------------------------------------ @@ -731,6 +731,96 @@ public class LibMatrixDNN { params.output.recomputeNonZeros(); } + + /** + * Performs the operation corresponding to the DML script: + * ones = matrix(1, rows=1, cols=Hout*Wout) + * output = input * matrix(bias %*% ones, rows=1, cols=F*Hout*Wout) + * This operation is often followed by conv2d and hence we have introduced bias_multiply(input, bias) built-in function + * + * @param input input matrix + * @param bias bias matrix + * @param outputBlock output matrix + * @param numThreads number of threads + * @throws DMLRuntimeException if DMLRuntimeException occurs + */ + public static void biasMultiply(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) throws DMLRuntimeException { + int N = input.getNumRows(); + int K = bias.getNumRows(); + int PQ = input.getNumColumns() / K; + + ConvolutionParameters params = new ConvolutionParameters(N, PQ, -1, -1, K, -1, -1, -1, -1, -1, -1, numThreads); + params.input1 = input; + params.input2 = bias; + params.output = outputBlock; + + if(!input.isInSparseFormat() && TEST_SPARSE_INPUT) { + input.denseToSparse(); + } + if(!bias.isInSparseFormat() && TEST_SPARSE_FILTER) { + bias.denseToSparse(); + } + + if(bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) { + throw new DMLRuntimeException("Incorrect inputs for bias_multiply: input[" + N + " X " + input.getNumColumns() + "] and bias[" + K + " X " + bias.getNumColumns() + "]"); + } + + if(!input.isEmptyBlock() && !bias.isEmptyBlock()) { + runConvTask(TaskType.BiasMultiply, params); + //post-processing: maintain nnz + params.output.recomputeNonZeros(); + } + else { + params.output.setNonZeros(0); + } + } + + private static void doBiasMultiply(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException { + double [] outputArray = params.output.getDenseBlock(); + int PQ = params.C; + int numOutCols = params.input1.getNumColumns(); + + if(!params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) { + double [] inputArr = params.input1.getDenseBlock(); + double [] biasArr = params.input2.getDenseBlock(); + int K = params.K; + int index = rl*K*PQ; + for(int n = rl; n < ru; n++) { + for(int k = 0; k < K; k++) { + for(int pq = 0; pq < PQ; pq++, index++) { + outputArray[index] = inputArr[index] * biasArr[k]; + } + } + } + } + else { + // Fill non-zero values + if(params.input1.isInSparseFormat()) { + Iterator<IJV> iter = params.input1.sparseBlock.getIterator(rl, ru); + while(iter.hasNext()) { + IJV ijv = iter.next(); + int i = ijv.getI(); + int j = ijv.getJ(); + outputArray[i*numOutCols + j] = ijv.getV(); + } + } + else { + System.arraycopy(params.input1.getDenseBlock(), 0, outputArray, 0, outputArray.length); + } + int K = params.K; + int index = rl*K*PQ; + for(int k = 0; k < K; k++) { + double val = params.input2.getValue(k, 1); + for(int n = rl; n < ru; n++) { + for(int pq = 0; pq < PQ; pq++, index++) { + outputArray[index] *= val; + } + } + } + } + + } + private static void doBiasAdd(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException { double [] outputArray = params.output.getDenseBlock(); int PQ = params.C; @@ -1009,6 +1099,9 @@ public class LibMatrixDNN { case BiasAdd: doBiasAdd(_params, _rl, _ru); break; + case BiasMultiply: + doBiasMultiply(_params, _rl, _ru); + break; case ReluBackward: lnnz = doReluBackward(_params, _rl, _ru); break;
