Repository: systemml Updated Branches: refs/heads/master 15ecb723e -> 99b1c2e25
[SYSTEMML-2109] Codegen support for maxpool/avgpool DNN operations This patch adds code generation support for maxpool and avgpool DNN operations to the codegen row-template. This way often entire joins of conv/maxpool/relu can be executed as fused operators without parallelization barriers per operator. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/99b1c2e2 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/99b1c2e2 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/99b1c2e2 Branch: refs/heads/master Commit: 99b1c2e252f935bbce5f53574d3e245221da3e68 Parents: 15ecb72 Author: Matthias Boehm <[email protected]> Authored: Tue Jul 24 18:32:50 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Jul 24 18:57:00 2018 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/DnnOp.java | 6 ++ .../apache/sysml/hops/codegen/cplan/CNode.java | 24 ++++++++ .../sysml/hops/codegen/cplan/CNodeNary.java | 58 ++++++++++++++++++-- .../sysml/hops/codegen/cplan/CNodeUnary.java | 22 +------- .../hops/codegen/template/TemplateRow.java | 14 ++++- .../runtime/codegen/LibSpoofPrimitives.java | 40 ++++++++++++++ .../matrix/data/LibMatrixDNNPooling.java | 16 +++--- .../functions/codegen/RowAggTmplTest.java | 6 +- .../scripts/functions/codegen/rowAggPattern44.R | 1 + .../functions/codegen/rowAggPattern44.dml | 1 + 10 files changed, 154 insertions(+), 34 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/DnnOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java index 8dbbeda..3b48371 100644 --- a/src/main/java/org/apache/sysml/hops/DnnOp.java +++ 
b/src/main/java/org/apache/sysml/hops/DnnOp.java @@ -217,6 +217,12 @@ public class DnnOp extends MultiThreadedHop isEqualAndKnown(param1.H, param2.H) && isEqualAndKnown(param1.W, param2.W); } + public boolean isStride1Pad0() { + DnnParameters tmp = parseInput(); + return tmp.stride_h == 1 && tmp.stride_w == 1 + && tmp.pad_h == 0 && tmp.pad_w == 0; + } + private static boolean isEqualAndKnown(int val1, int val2) { return val1 >= 0 && val2 >= 0 && val1 == val2; } http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java index b0efb42..6dab878 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; +import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.util.UtilFunctions; @@ -204,4 +205,27 @@ public abstract class CNode && _dataType == cthat._dataType && _literal == cthat._literal; } + + protected String replaceUnaryPlaceholders(String tmp, String varj, boolean vectIn) { + //replace sparse and dense inputs + tmp = tmp.replace("%IN1v%", varj+"vals"); + tmp = tmp.replace("%IN1i%", varj+"ix"); + tmp = tmp.replace("%IN1%", + (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" : + (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj)); + + //replace start position of main input + String spos = (_inputs.get(0) instanceof CNodeData + && _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ? 
+ varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".pos(rix)" : "0" : "0"; + + tmp = tmp.replace("%POS1%", spos); + tmp = tmp.replace("%POS2%", spos); + + //replace length + if( _inputs.get(0).getDataType().isMatrix() ) + tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength()); + + return tmp; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java index 7f19194..e720601 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java @@ -20,15 +20,21 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; +import java.util.List; +import org.apache.commons.lang3.StringUtils; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.util.DnnUtils; import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeNary extends CNode { public enum NaryType { - VECT_CBIND; + VECT_CBIND, + VECT_MAX_POOL, + VECT_AVG_POOL; + public static boolean contains(String value) { for( NaryType bt : values() ) if( bt.name().equals(value) ) @@ -56,12 +62,19 @@ public class CNodeNary extends CNode off += input._cols; } return sb.toString(); + case VECT_MAX_POOL: + case VECT_AVG_POOL: + String vectName = (this==VECT_MAX_POOL) ? "Maxpool" : "Avgpool"; + String paramStr = getPoolingParameterString(inputs); + return sparseGen ? 
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len, "+paramStr+");\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%, "+paramStr+");\n"; default: throw new RuntimeException("Invalid nary type: "+this.toString()); } } public boolean isVectorPrimitive() { - return this == VECT_CBIND; + return this == VECT_CBIND || this == VECT_MAX_POOL || this == VECT_AVG_POOL; } } @@ -90,10 +103,17 @@ public class CNodeNary extends CNode sb.append(in.codegen(sparse)); //generate nary operation (use sparse template, if data input) + boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData + && _inputs.get(0).getVarname().startsWith("a") + && !_inputs.get(0).isLiteral()); String var = createVarname(); - String tmp = _type.getTemplate(sparse, _cols, _inputs); + String tmp = _type.getTemplate(lsparse, _cols, _inputs); tmp = tmp.replace("%TMP%", var); + //replace sparse and dense inputs + String varj = _inputs.get(0).getVarname(); + tmp = replaceUnaryPlaceholders(tmp, varj, false); + sb.append(tmp); //mark as generated @@ -105,7 +125,9 @@ public class CNodeNary extends CNode @Override public String toString() { switch(_type) { - case VECT_CBIND: return "n(cbind)"; + case VECT_CBIND: return "n(cbind)"; + case VECT_MAX_POOL: return "n(maxpool)"; + case VECT_AVG_POOL: return "n(avgpool)"; default: return "m("+_type.name().toLowerCase()+")"; } @@ -121,6 +143,19 @@ public class CNodeNary extends CNode _cols += in._cols; _dataType = DataType.MATRIX; break; + case VECT_MAX_POOL: + case VECT_AVG_POOL: //only stride 1, pad 0 + int C = Integer.parseInt(_inputs.get(6).getVarname()); + int H = Integer.parseInt(_inputs.get(7).getVarname()); + int W = Integer.parseInt(_inputs.get(8).getVarname()); + int R = Integer.parseInt(_inputs.get(11).getVarname()); + int S = Integer.parseInt(_inputs.get(12).getVarname()); + long P = DnnUtils.getP(H, R, 1, 0); + long Q = DnnUtils.getQ(W, S, 1, 0); + _rows = 
_inputs.get(0)._rows; + _cols = C * P * Q; + _dataType = DataType.MATRIX; + break; } } @@ -142,4 +177,19 @@ public class CNodeNary extends CNode return super.equals(that) && _type == that._type; } + + private static String getPoolingParameterString(List<CNode> inputs) { + //extract and derive individual parameters + int C = Integer.parseInt(inputs.get(6).getVarname()); + int H = Integer.parseInt(inputs.get(7).getVarname()); + int W = Integer.parseInt(inputs.get(8).getVarname()); + int R = Integer.parseInt(inputs.get(11).getVarname()); + int S = Integer.parseInt(inputs.get(12).getVarname()); + int P = (int) DnnUtils.getP(H, R, 1, 0); + int Q = (int) DnnUtils.getQ(W, S, 1, 0); + + //construct parameter string + return "rix, " + StringUtils.join( + new int[]{C, P, Q, R, S, H, W}, ','); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index ba41fad..21f7fe7 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -23,7 +23,6 @@ import java.util.Arrays; import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.StringUtils; -import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.util.UtilFunctions; @@ -214,27 +213,10 @@ public class CNodeUnary extends CNode String tmp = _type.getTemplate(lsparse); tmp = tmp.replace("%TMP%", var); - String varj = _inputs.get(0).getVarname(); - //replace sparse and dense inputs + String varj = _inputs.get(0).getVarname(); boolean vectIn = varj.startsWith("b") && !_type.isScalarLookup(); - tmp = tmp.replace("%IN1v%", varj+"vals"); - tmp = 
tmp.replace("%IN1i%", varj+"ix"); - tmp = tmp.replace("%IN1%", - (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" : - (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj)); - - //replace start position of main input - String spos = (_inputs.get(0) instanceof CNodeData - && _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ? - varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".pos(rix)" : "0" : "0"; - - tmp = tmp.replace("%POS1%", spos); - tmp = tmp.replace("%POS2%", spos); - - //replace length - if( _inputs.get(0).getDataType().isMatrix() ) - tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength()); + tmp = replaceUnaryPlaceholders(tmp, varj, vectIn); sb.append(tmp); http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 9df67d0..9885909 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -115,7 +115,9 @@ public class TemplateRow extends TemplateBase && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)hop)) || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT) && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() - && hop.getInput().get(0).getDim2()>1); + && hop.getInput().get(0).getDim2()>1) + || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) + && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0()); } @Override @@ -140,6 +142,8 @@ public class TemplateRow extends TemplateBase || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT) && hop.getInput().get(0).dimsKnown() && 
hop.getInput().get(1).dimsKnown() && hop.getInput().get(0).getDim2()>1) + || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) + && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0()) || isPartOfValidCumAggChain(hop) //cum* with transpose || isPartOfValidTransposeMMChain(hop)); //t(f(X))%*%X } @@ -156,6 +160,8 @@ public class TemplateRow extends TemplateBase || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT) && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() && hop.getInput().get(0).getDim2()>1 ) + || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) + && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0()) || (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, DataGenMethod.SEQ) && HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false)) || (hop instanceof AggBinaryOp @@ -476,6 +482,12 @@ public class TemplateRow extends TemplateBase out = new CNodeBinary(cdata1, cdata2, BinType.valueOf("VECT_"+((DnnOp)hop).getOp().name())); } + else if( HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) ) { + CNode[] in = hop.getInput().stream().map(h -> + tmp.get(h.getHopID())).toArray(CNode[]::new); + out = new CNodeNary(in, CNodeNary.NaryType + .valueOf("VECT_"+((DnnOp)hop).getOp().name())); + } else if( hop instanceof NaryOp ) { CNode[] inputs = new CNode[hop.getInput().size()]; for( int i=0; i<hop.getInput().size(); i++ ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java index fc0c1d2..c1460ce 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java @@ -26,7 
+26,9 @@ import org.apache.sysml.runtime.functionobjects.BitwAnd; import org.apache.sysml.runtime.functionobjects.IntegerDivide; import org.apache.sysml.runtime.functionobjects.Modulus; import org.apache.sysml.runtime.matrix.data.LibMatrixDNN; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPooling; import org.apache.sysml.runtime.matrix.data.LibMatrixMult; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType; /** * This library contains all vector primitives that are used in @@ -2052,7 +2054,45 @@ public class LibSpoofPrimitives LibMatrixDNN.multBias(c, b, 1, b.length, len/b.length); return c; } + + //maxpool + + public static double[] vectMaxpoolWrite(double[] a, int ai, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) { + double[] c = allocVector(C*P*Q, true); + LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.MAX, + -Double.MAX_VALUE, 1, a, c, rix, rix+1, ai, 0, C, P, Q, R, S, H, W); + return c; + } + + public static double[] vectMaxpoolWrite(double[] avals, int[] aix, int ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) { + double[] a = allocVector(len, true); + double[] c = allocVector(C*P*Q, true); + for(int k=ai; k<ai+alen; k++) + a[aix[k]] = avals[k]; + LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.MAX, + -Double.MAX_VALUE, 1, a, c, rix, rix+1, 0, 0, C, P, Q, R, S, H, W); + return c; + } + + //avgpool + public static double[] vectAvgpoolWrite(double[] a, int ai, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) { + double[] c = allocVector(C*P*Q, true); + LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.AVG, + 0, 1d/(R*S), a, c, rix, rix+1, ai, 0, C, P, Q, R, S, H, W); + return c; + } + + public static double[] vectAvgpoolWrite(double[] avals, int[] aix, int ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) { + double[] a = allocVector(len, true); + double[] c = allocVector(C*P*Q, true); + for(int k=ai;
k<ai+alen; k++) + a[aix[k]] = avals[k]; + LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.AVG, + 0, 1d/(R*S), a, c, rix, rix+1, 0, 0, C, P, Q, R, S, H, W); + return c; + } + //complex builtin functions that are not directly generated //(included here in order to reduce the number of imports) http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java index 4ff8e5e..7fb33a4 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java @@ -97,8 +97,8 @@ public class LibMatrixDNNPooling { return ret; } - public static void poolingDenseStride1Pad0(PoolingType pType, double minVal, double pFact, - double[] in, double[] out, int rl, int ru, int C, int P, int Q, int R, int S, int H, int W) { + public static void poolingDenseStride1Pad0(PoolingType pType, double minVal, double pFact, double[] in, + double[] out, int rl, int ru, int ii, int oi, int C, int P, int Q, int R, int S, int H, int W) { boolean max = (pType == PoolingType.MAX); int CHW = C * H * W; @@ -106,9 +106,9 @@ public class LibMatrixDNNPooling { //quick-path w/o materialized index arrays and //simplified inner loops for P = 1, Q = 1, W = 1 int lenh = Math.min(R,H); - for(int i = rl, oix=rl*C; i < ru; i++, oix+=C) - for (int c = 0, off=i*CHW; c < C; c++, off+=H) { - out[oix+c] = max ? max(minVal, in, off, lenh) : + for(int i = rl; i < ru; i++, oi+=C) + for (int c = 0, off=ii+(i-rl)*CHW; c < C; c++, off+=H) { + out[oi+c] = max ?
max(minVal, in, off, lenh) : avg(minVal, in, off, lenh, pFact); } } @@ -117,7 +117,7 @@ public class LibMatrixDNNPooling { - Arrays.fill(out, rl*CPQ, ru*CPQ, minVal); + Arrays.fill(out, oi, oi+(ru-rl)*CPQ, minVal); //quick-path w/o materialized index arrays for(int i = rl; i < ru; i++) - for (int c = 0, off=i*CHW, oix=i*CPQ; c < C; c++, off+=HW) + for (int c = 0, off=ii+(i-rl)*CHW, oix=oi+(i-rl)*CPQ; c < C; c++, off+=HW) for (int p = 0; p < P; p++, oix+=Q) for (int h = p; h < Math.min(p+R,H); h++) for (int q = 0, off2=off+h*W; q < Q; q++) { @@ -139,7 +139,7 @@ public class LibMatrixDNNPooling { _rl = rl; _ru = ru; _params = params; _poolingType = poolingType; - _poolingMultiplier = Math.pow(params.R*params.S, -1); + _poolingMultiplier = 1d/(params.R*params.S); } @Override @@ -157,7 +157,7 @@ public class LibMatrixDNNPooling { if( _params.isStride1Pad0() ) { poolingDenseStride1Pad0(_poolingType, minValForMaxPoolOperations, - _poolingMultiplier, in, out, _rl, _ru, C, P, Q, R, S, H, W); + _poolingMultiplier, in, out, _rl, _ru, _rl*CHW, _rl*CPQ, C, P, Q, R, S, H, W); } else { //general case //thread-local initialization of output block http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index 04891d0..48555ae 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -80,7 +80,7 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME41 = TEST_NAME+"41"; //X*rowSums(X/seq(1,N)+t(seq(M,1))) private static final String TEST_NAME42 = TEST_NAME+"42"; //X/rowSums(min(X, Y, Z)) private static 
final String TEST_NAME43 =
TEST_NAME+"43"; //bias_add(X,B) + bias_mult(X,B) - private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X)); + private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X)) + 7; private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -817,6 +817,10 @@ public class RowAggTmplTest extends AutomatedTestBase if( testname.equals(TEST_NAME42) ) Assert.assertTrue(!heavyHittersContainsSubString("min","nmin") && !heavyHittersContainsSubString("spoof", 2)); + if( testname.equals(TEST_NAME44) ) + Assert.assertTrue(!heavyHittersContainsSubString("maxpooling") + && !heavyHittersContainsSubString("spoof", 2)); + } finally { rtplatform = platformOld; http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/scripts/functions/codegen/rowAggPattern44.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern44.R b/src/test/scripts/functions/codegen/rowAggPattern44.R index 99ba0b0..7269df3 100644 --- a/src/test/scripts/functions/codegen/rowAggPattern44.R +++ b/src/test/scripts/functions/codegen/rowAggPattern44.R @@ -95,5 +95,6 @@ max_pool <- function(X, N, C, Hin, Win, Hf, Wf, } R = max_pool(X, numImg, numChannels, imgSize*imgSize, 1, poolSize1, poolSize2, stride, stride) +R = R + 7; writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep="")) http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/scripts/functions/codegen/rowAggPattern44.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern44.dml b/src/test/scripts/functions/codegen/rowAggPattern44.dml index f236451..f5e7b6c 100644 --- a/src/test/scripts/functions/codegen/rowAggPattern44.dml +++ b/src/test/scripts/functions/codegen/rowAggPattern44.dml @@ -31,5 +31,6 @@ while(FALSE){} X = X - 
rowMeans(X); R = max_pool(X, stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize*imgSize, 1], pool_size=[poolSize1, poolSize2]); +R = R + 7; write(R, $1, format="text");
