[SYSTEMML-1979] Improved codegen optimizer (cost model, various fixes) This patch makes a number of improvements to the codegen optimizer, which help to exploit missed fusion potential for Kmeans over large distributed datasets (i.e., with Spark codegen operations). In detail, this includes the following changes:
1) Eviction-aware cost model: So far we only took the write memory bandwidth into account. With this change we also account for known evictions whenever the output and temporary intermediate inputs are known not to fit into the buffer pool. 2) Generalized exploration of row fusion plans: This generalization now allows fusing matrix-matrix multiplications onto arbitrary row operations, which makes it possible to fuse the entire Kmeans inner loop if beneficial. 3) Row sumSq vector primitives: Additionally, we now compile sumSq vector primitives instead of sum(pow(X,2)), which helps to avoid unnecessary dense row vector intermediates. 4) Fix missing dense-sparse outer vector operations: So far we only supported sparse-dense outer vector operations. With the above change (2), the sparse input can also occur on the right-hand side. 5) Fix cost model for unary aggregates: The compute costs for all types of unary aggregates were incorrectly computed based on the output size instead of the input size. 6) Fix row to cell conversion: This patch also makes some smaller corrections for the conversion of row to cell templates if there are no aggregations or vector operations (e.g., only convert if a cell template exists with exactly the same fusion references). 7) Fix robustness of temporary memory management: We use a preallocated ring buffer for row vector intermediates of different sizes. This patch restricts the size of preallocated vectors to 4MB to avoid OOM in case the large vectors are not used by an operator. On a 200M x 100 dense input matrix (160GB), this patch improved the end-to-end runtime of Kmeans (20 iterations, 5 centroids) w/ codegen from 5319s to 281s.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d907efc1 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d907efc1 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d907efc1 Branch: refs/heads/master Commit: d907efc17456d7536e1a7344a614aa8a122721ee Parents: d916ba5 Author: Matthias Boehm <mboe...@gmail.com> Authored: Sun Oct 29 23:21:30 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Mon Oct 30 18:37:34 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeBinary.java | 3 +- .../sysml/hops/codegen/cplan/CNodeUnary.java | 11 +-- .../opt/PlanSelectionFuseCostBasedV2.java | 72 +++++++++++++------- .../hops/codegen/template/CPlanMemoTable.java | 10 +++ .../hops/codegen/template/TemplateRow.java | 47 ++++++++++--- .../runtime/codegen/LibSpoofPrimitives.java | 27 ++++++++ .../gpu/ConvolutionGPUInstruction.java | 5 -- .../functions/codegen/RowAggTmplTest.java | 18 ++++- .../scripts/functions/codegen/rowAggPattern33.R | 36 ++++++++++ .../functions/codegen/rowAggPattern33.dml | 33 +++++++++ 10 files changed, 215 insertions(+), 47 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index d188afd..8c3c73d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -77,7 +77,8 @@ public class CNodeBinary extends CNode return sparseLhs ? 
" double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; case VECT_OUTERMULT_ADD: - return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : + return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : + sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n"; //vector-scalar-add operations http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index 3a3dc79..891bfb9 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -30,7 +30,7 @@ public class CNodeUnary extends CNode { public enum UnaryType { LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific - ROW_SUMS, ROW_MINS, ROW_MAXS, ROW_COUNTNNZS, //codegen specific + ROW_SUMS, ROW_SUMSQS, ROW_MINS, ROW_MAXS, ROW_COUNTNNZS, //codegen specific VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG, VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN, VECT_SIN, VECT_COS, VECT_TAN, VECT_ASIN, VECT_ACOS, VECT_ATAN, @@ -51,6 +51,7 @@ public class CNodeUnary extends CNode public String getTemplate(boolean sparse) { switch( this ) { case ROW_SUMS: + case ROW_SUMSQS: case ROW_MINS: case ROW_MAXS: case ROW_COUNTNNZS: { @@ -242,9 +243,10 @@ public 
class CNodeUnary extends CNode @Override public String toString() { switch(_type) { - case ROW_SUMS: return "u(R+)"; - case ROW_MINS: return "u(Rmin)"; - case ROW_MAXS: return "u(Rmax)"; + case ROW_SUMS: return "u(R+)"; + case ROW_SUMSQS: return "u(Rsq+)"; + case ROW_MINS: return "u(Rmin)"; + case ROW_MAXS: return "u(Rmax)"; case ROW_COUNTNNZS: return "u(Rnnz)"; case VECT_EXP: case VECT_POW2: @@ -308,6 +310,7 @@ public class CNodeUnary extends CNode break; case ROW_SUMS: + case ROW_SUMSQS: case ROW_MINS: case ROW_MAXS: case ROW_COUNTNNZS: http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 10875e8..4d8a7bc 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -42,6 +42,7 @@ import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.IndexingOp; @@ -60,6 +61,7 @@ import org.apache.sysml.hops.codegen.template.TemplateRow; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.runtime.codegen.LibSpoofPrimitives; +import org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import 
org.apache.sysml.runtime.util.UtilFunctions; @@ -79,23 +81,24 @@ import org.apache.sysml.utils.Statistics; * */ public class PlanSelectionFuseCostBasedV2 extends PlanSelection -{ +{ private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName()); //common bandwidth characteristics, with a conservative write bandwidth in order //to cover result allocation, write into main memory, and potential evictions - private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024; //2GB/s - private static final double READ_BANDWIDTH = 32d*1024*1024*1024; //32GB/s - private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH/4; - private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core - * InfrastructureAnalyzer.getLocalParallelism(); + private static final double WRITE_BANDWIDTH_IO = 512*1024*1024; //512MB/s + private static final double WRITE_BANDWIDTH_MEM = 2d*1024*1024*1024; //2GB/s + private static final double READ_BANDWIDTH_MEM = 32d*1024*1024*1024; //32GB/s + private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH_MEM/4; + private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //1GFLOPs/core + * InfrastructureAnalyzer.getLocalParallelism(); //sparsity estimate for unknown sparsity to prefer sparse-safe fusion plans private static final double SPARSE_SAFE_SPARSITY_EST = 0.1; //optimizer configuration public static boolean COST_PRUNING = true; - public static boolean STRUCTURAL_PRUNING = false; + public static boolean STRUCTURAL_PRUNING = true; private static final IDSequence COST_ID = new IDSequence(); private static final TemplateRow ROW_TPL = new TemplateRow(); @@ -306,8 +309,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection matTargets.add(hopID); Hop hop = memo.getHopRefs().get(hopID); long size = getSize(hop); - costs += size * 8 / WRITE_BANDWIDTH + - size * 8 / READ_BANDWIDTH; + costs += size * 8 / WRITE_BANDWIDTH_MEM + + size * 8 / READ_BANDWIDTH_MEM; } } 
//points with non-partition consumers @@ -315,7 +318,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection if( !matTargets.contains(hopID) ) { matTargets.add(hopID); Hop hop = memo.getHopRefs().get(hopID); - costs += getSize(hop) * 8 / WRITE_BANDWIDTH; + costs += getSize(hop) * 8 / WRITE_BANDWIDTH_MEM; } return costs; @@ -326,7 +329,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //get partition input reads (at least read once) for( Long hopID : part.getInputs() ) { Hop hop = memo.getHopRefs().get(hopID); - costs += getSize(hop) * 8 / READ_BANDWIDTH; + costs += getSize(hop) * 8 / READ_BANDWIDTH_MEM; } return costs; } @@ -335,7 +338,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection double costs = 0; for( Long hopID : R ) { Hop hop = memo.getHopRefs().get(hopID); - costs += getSize(hop) * 8 / WRITE_BANDWIDTH; + costs += getSize(hop) * 8 / WRITE_BANDWIDTH_MEM; } return costs; } @@ -345,6 +348,13 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection .mapToDouble(d -> d/COMPUTE_BANDWIDTH).sum(); } + private static double sumTmpInputOutputSize(CPlanMemoTable memo, CostVector vect) { + //size of intermediate inputs and outputs, i.e., output and inputs other than treads + return vect.outSize + vect.inSizes.entrySet().stream() + .filter(e -> !HopRewriteUtils.isData(memo.getHopRefs().get(e.getKey()), DataOpTypes.TRANSIENTREAD)) + .mapToDouble(e -> e.getValue()).sum(); + } + private static long getSize(Hop hop) { return Math.max(hop.getDim1(),1) * Math.max(hop.getDim2(),1); @@ -593,6 +603,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection private static boolean isRowAggOp(Hop hop){ return (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp + || (hop instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)hop)) || HopRewriteUtils.isBinary(hop, OpOp2.CBIND)); } @@ -629,7 +640,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection HashSet<Long> refAggs = 
getRowAggOpsWithRowRef(memo, part); for( Long hopID : part.getPartition() ) { MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); - if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL) + if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, me, TemplateType.CELL) && rIsRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>(), refAggs.contains(hopID)) ) { List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(blacklist)); @@ -829,12 +840,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //add costs for opened fused operator if( opened ) { - if( LOG.isTraceEnabled() ) { - String type = (best !=null) ? best.type.name() : "HOP"; - LOG.trace("Cost vector ("+type+" "+currentHopId+"): "+costVect); - } - double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH - + Math.max(costVect.getInputSize() * 8 / READ_BANDWIDTH, + double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH_MEM + + Math.max(costVect.getInputSize() * 8 / READ_BANDWIDTH_MEM, costVect.computeCosts/ COMPUTE_BANDWIDTH); //read correction for distributed computation Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); @@ -843,7 +850,15 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //sparsity correction for outer-product template (and sparse-safe cell) if( best != null && best.type == TemplateType.OUTER ) tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST; + //write correction for known evictions in CP + else if( driver.getMemEstimate() < OptimizerUtils.getLocalMemBudget() + && sumTmpInputOutputSize(memo, costVect) > LazyWriteBuffer.getWriteBufferSize() ) + tmpCosts += costVect.outSize * 8 / WRITE_BANDWIDTH_IO; costs += tmpCosts; + if( LOG.isTraceEnabled() ) { + String type = (best !=null) ? 
best.type.name() : "HOP"; + LOG.trace("Cost vector ("+type+" "+currentHopId+"): "+costVect+" -> "+tmpCosts); + } } //add costs for non-partition read in the middle of fused operator else if( part.getExtConsumed().contains(current.getHopID()) ) { @@ -985,13 +1000,18 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } else if( current instanceof AggUnaryOp) { switch(((AggUnaryOp)current).getOp()) { - case SUM: costs = 4; break; - case SUM_SQ: costs = 5; break; - case MIN: - case MAX: costs = 1; break; - default: - LOG.warn("Cost model not " - + "implemented yet for: "+((AggUnaryOp)current).getOp()); + case SUM: costs = 4; break; + case SUM_SQ: costs = 5; break; + case MIN: + case MAX: costs = 1; break; + default: + LOG.warn("Cost model not " + + "implemented yet for: "+((AggUnaryOp)current).getOp()); + } + switch(((AggUnaryOp)current).getDirection()) { + case Col: costs *= Math.max(current.getInput().get(0).getDim1(),1); break; + case Row: costs *= Math.max(current.getInput().get(0).getDim2(),1); break; + case RowCol: costs *= getSize(current.getInput().get(0)); break; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 5eedc7b..0c3bb90 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -87,6 +87,11 @@ public class CPlanMemoTable .anyMatch(p -> p.type==type); } + public boolean contains(long hopID, MemoTableEntry me, TemplateType type) { + return contains(hopID) && get(hopID).stream() + .anyMatch(p -> p.type==type && p.equalPlanRefs(me)); + } + public boolean contains(long hopID, boolean checkClose, TemplateType... 
type) { if( !checkClose && type.length==1 ) return contains(hopID, type[0]); @@ -408,6 +413,11 @@ public class CPlanMemoTable return (input1>=0) ? 0 : (input2>=0) ? 1 : (input3>=0) ? 2 : -1; } + public boolean equalPlanRefs(MemoTableEntry that) { + return (input1 == that.input1 + && input2 == that.input2 + && input3 == that.input3); + } public long input(int index) { return (index==0) ? input1 : (index==1) ? input2 : input3; } http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index e14fbd3..b862abf 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -115,7 +115,8 @@ public class TemplateRow extends TemplateBase || (hop instanceof AggBinaryOp && hop.dimsKnown() && isFuseSkinnyMatrixMult(hop) //MM && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)) && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) - || isPartOfValidCumAggChain(hop) ); //cum* with transpose + || isPartOfValidCumAggChain(hop) //cum* with transpose + || isPartOfValidTransposeMMChain(hop)); //t(f(X))%*%X } @Override @@ -176,6 +177,22 @@ public class TemplateRow extends TemplateBase && hop.getInput().get(0).getParent().size()==1); } } + + private static boolean isPartOfValidTransposeMMChain(Hop hop) { + //check if transpose is part of t(f(X))%*%X chain w/ single consumer + //for now: we restrict this to tall and skinny matrix multiplications + return HopRewriteUtils.isTransposeOperation(hop) + && hop.getParent().size() == 1 + && hop.dimsKnown() && hop.getParent().get(0).dimsKnown() + && hop.getDim2() > 128 * hop.getParent().get(0).getDim1() + && hop.getDim2() 
> 128 * hop.getParent().get(0).getDim2() + && HopRewriteUtils.isMatrixMultiply(hop.getParent().get(0)) + && isFuseSkinnyMatrixMult(hop.getParent().get(0)) + && ((hop.getParent().get(0).getInput().get(0) == hop && + HopRewriteUtils.containsInput(hop, hop.getParent().get(0).getInput().get(1))) + ||(hop.getParent().get(0).getInput().get(1) == hop && + HopRewriteUtils.containsInput(hop, hop.getParent().get(0).getInput().get(0)))); + } @Override public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) { @@ -214,7 +231,7 @@ public class TemplateRow extends TemplateBase } private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) - { + { //memoization for common subexpression elimination and to avoid redundant work if( tmp.containsKey(hop.getHopID()) ) return; @@ -240,15 +257,20 @@ public class TemplateRow extends TemplateBase if( ((AggUnaryOp)hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG) ) { if(hop.getInput().get(0).getDim2()==1) out = (cdata1.getDataType()==DataType.SCALAR) ? 
cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R); - else if( HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM) + else if( HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM) && HopRewriteUtils.isBinaryMatrixScalar(hop.getInput().get(0), OpOp2.NOTEQUAL, 0) && cdata1 instanceof CNodeBinary ) { out = new CNodeUnary(cdata1.getInput().get(0), UnaryType.ROW_COUNTNNZS); } + else if( HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM) + && HopRewriteUtils.isBinaryMatrixScalar(hop.getInput().get(0), OpOp2.POW, 2) + && cdata1 instanceof CNodeBinary ) { + out = new CNodeUnary(cdata1.getInput().get(0), UnaryType.ROW_SUMSQS); + } else { String opcode = "ROW_"+((AggUnaryOp)hop).getOp().name().toUpperCase()+"S"; out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode)); - if( cdata1 instanceof CNodeData && inHops2.isEmpty() ) + if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") ) inHops2.put("X", hop.getInput().get(0)); } } @@ -275,17 +297,22 @@ public class TemplateRow extends TemplateBase //correct input under transpose cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals); inHops.remove(hop.getInput().get(0)); - inHops.add(hop.getInput().get(0).getInput().get(0)); + if( cdata1 instanceof CNodeData ) + inHops.add(hop.getInput().get(0).getInput().get(0)); //note: vectorMultAdd applicable to vector-scalar, and vector-vector if( hop.getInput().get(1).getDim2() == 1 ) out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD); else { out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD); - if( !inHops2.containsKey("B1") ) + if( !inHops2.containsKey("B1") ) { //incl modification of X for consistency + if( cdata1 instanceof CNodeData ) + inHops2.put("X", hop.getInput().get(0).getInput().get(0)); inHops2.put("B1", hop.getInput().get(1)); + } } - inHops2.put("X", hop.getInput().get(0).getInput().get(0)); + if( !inHops2.containsKey("X") ) + inHops2.put("X", hop.getInput().get(0).getInput().get(0)); } else { @@ -321,7 +348,7 @@ public class 
TemplateRow extends TemplateBase if( HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY) ) { String opname = "VECT_"+((UnaryOp)hop).getOp().name(); out = new CNodeUnary(cdata1, UnaryType.valueOf(opname)); - if( cdata1 instanceof CNodeData && inHops2.isEmpty() ) + if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") ) inHops2.put("X", hop.getInput().get(0)); } else @@ -350,7 +377,7 @@ public class TemplateRow extends TemplateBase cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1)); } out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND); - if( cdata1 instanceof CNodeData ) + if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") ) inHops2.put("X", hop.getInput().get(0)); } else if(hop instanceof BinaryOp) @@ -379,7 +406,7 @@ public class TemplateRow extends TemplateBase cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname)); } - if( cdata1 instanceof CNodeData && inHops2.isEmpty() + if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType()==DataType.SCALAR) ) { inHops2.put("X", hop.getInput().get(0)); } http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java index 356c729..1d56f1c 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java @@ -118,6 +118,19 @@ public class LibSpoofPrimitives } } + public static void vectOuterMultAdd(double[] a, double[] b, double[] c, int ai, int[] bix, int bi, int ci, int blen, int len1, int len2) { + if( isFlipOuter(len1, len2) ) { + for( int i=bi; i<bi+blen; i++ ) { + final int cix = ci + bix[i] * len1; 
+ LibMatrixMult.vectMultiplyAdd(b[i], a, c, ai, cix, len1); + } + } + else { + for( int i=0, cix=ci; i < len1; i++, cix+=len2 ) + LibMatrixMult.vectMultiplyAdd(a[ai+i], b, c, bix, bi, cix, blen); + } + } + public static void vectMultAdd(double[] a, double bval, double[] c, int bi, int ci, int len) { if( a == null || bval == 0 ) return; LibMatrixMult.vectMultiplyAdd(bval, a, c, bi, ci, len); @@ -257,6 +270,14 @@ public class LibSpoofPrimitives return vectSum(avals, ai, alen); } + public static double vectSumsq(double[] a, int ai, int len) { + return LibMatrixMult.dotProduct(a, a, ai, ai, len); + } + + public static double vectSumsq(double[] avals, int[] aix, int ai, int alen, int len) { + return LibMatrixMult.dotProduct(avals, avals, ai, ai, alen); + } + public static double vectMin(double[] a, int ai, int len) { double val = Double.MAX_VALUE; for( int i = ai; i < ai+len; i++ ) @@ -1837,12 +1858,18 @@ public class LibSpoofPrimitives * vectors of different sizes are interspersed. */ private static class VectorBuffer { + private static final int MAX_SIZE = 512*1024; //4MB private final double[][] _data; private int _pos; private int _len1; private int _len2; public VectorBuffer(int num, int len1, int len2) { + //best effort size restriction since large intermediates + //not necessarily used (num refers to the total number) + len1 = Math.min(len1, MAX_SIZE); + len2 = Math.min(len2, MAX_SIZE); + //pre-allocate ring buffer int lnum = (len2>0 && len1!=len2) ? 
2*num : num; _data = new double[lnum][]; for( int i=0; i<num; i++ ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java index fdb208e..62a20b8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java @@ -20,17 +20,12 @@ package org.apache.sysml.runtime.instructions.gpu; import java.util.ArrayList; -import jcuda.Pointer; - import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.functionobjects.SwapIndex; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; -import org.apache.sysml.runtime.instructions.cp.ConvolutionCPInstruction; -import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig; -import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA; import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index 78305e3..5d2015f 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -69,6 +69,7 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME30 = TEST_NAME+"30"; //Mlogreg inner core, multi-class private static final String TEST_NAME31 = TEST_NAME+"31"; //MLogreg - matrix-vector cbind 0s generalized private static final String TEST_NAME32 = TEST_NAME+"32"; //X[, 1] - rowSums(X) + private static final String TEST_NAME33 = TEST_NAME+"33"; //Kmeans, inner loop private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -80,7 +81,7 @@ public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=32; i++) + for(int i=1; i<=33; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @@ -564,6 +565,21 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME32, false, ExecType.SPARK ); } + @Test + public void testCodegenRowAggRewrite33CP() { + testCodegenIntegration( TEST_NAME33, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg33CP() { + testCodegenIntegration( TEST_NAME33, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg33SP() { + testCodegenIntegration( TEST_NAME33, false, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/test/scripts/functions/codegen/rowAggPattern33.R 
---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern33.R b/src/test/scripts/functions/codegen/rowAggPattern33.R new file mode 100644 index 0000000..2d2490d --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern33.R @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = matrix(seq(1,6000)/6000, 600, 10, byrow=TRUE); +C = matrix(seq(1,40)/40, 4, 10, byrow=TRUE); + +D = -2 * (X %*% t(C)) + matrix(1,nrow(X),1)%*%t(rowSums (C ^ 2)); +P = (D <= (rowMins (D) %*% matrix(1, 1, ncol(D)))); +P = P / rowSums (P); +P_denom = colSums (P); +R = (t(P) %*% X) / P_denom%*%matrix(1,1,ncol(X)); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/d907efc1/src/test/scripts/functions/codegen/rowAggPattern33.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern33.dml b/src/test/scripts/functions/codegen/rowAggPattern33.dml new file mode 100644 index 0000000..54c277c --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern33.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +X = matrix(seq(1,6000)/6000, 600, 10); +C = matrix(seq(1,40)/40, 4, 10); +while(FALSE){} + +D = -2 * (X %*% t(C)) + t(rowSums (C ^ 2)); +P = D <= rowMins(D); +P = P / rowSums (P); +P_denom = colSums (P); +R = (t(P) %*% X) / t(P_denom); + +write(R, $1)