http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java index bc6ef5b..8764a80 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java @@ -20,470 +20,191 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; +import java.util.LinkedList; -import org.apache.sysml.api.DMLException; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.Hop.OpOp2; -import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; -import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; -import org.apache.sysml.hops.codegen.cplan.CNodeUnary; -import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.rewrite.HopRewriteUtils; -import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; import org.apache.sysml.runtime.matrix.data.Pair; public class OuterProductTpl extends BaseTpl { public OuterProductTpl() { - super(TemplateType.OuterProductTpl); + super(TemplateType.OuterProdTpl); } - - private List<OpOp2> sparseDrivers = new ArrayList<OpOp2>(Arrays.asList(OpOp2.MULT, OpOp2.DIV)); - private OutProdType _outerProductType = null; - private boolean _transposeOutput = false; - private boolean _transposeInput = false; - + + @Override + public boolean open(Hop hop) { + //open on outer product like matrix mult (output larger than common dim) + return HopRewriteUtils.isOuterProductLikeMM(hop) + && hop.getDim1()>256 && hop.getDim2() > 256; + } + @Override - public boolean openTpl(Hop hop) { - // outerproduct ( output dimensions is greater than the common dimension) - return ( hop instanceof AggBinaryOp && ((AggBinaryOp)hop).isMatrixMultiply() && hop.dimsKnown() - && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() - && (hop.getDim1() > hop.getInput().get(0).getDim2() && hop.getDim2() > hop.getInput().get(1).getDim1()) ); + public boolean fuse(Hop hop, Hop input) { + return !isClosed() + &&((hop instanceof UnaryOp && TemplateUtils.isOperationSupported(hop)) + || (hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop)) + || HopRewriteUtils.isTransposeOperation(hop) + || (hop instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop)) + || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()==Direction.RowCol)); } @Override - public boolean findTplBoundaries(Hop h, CplanRegister cplanRegister) { - _endHop = h;//outerProduct tpl starts with endHop - HashMap<String,Hop> uniqueMatrixInputs = new HashMap<String,Hop>(); - uniqueMatrixInputs.put("U", h.getInput().get(0)); - if( h.getInput().get(1) instanceof ReorgOp && ((ReorgOp)h.getInput().get(1)).getOp() == ReOrgOp.TRANSPOSE ) - uniqueMatrixInputs.put("V", h.getInput().get(1).getInput().get(0)); + public boolean merge(Hop hop, Hop input) { + return !isClosed() && + (TemplateUtils.isBinaryMatrixRowVector(hop) + || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)); + } + + @Override + public CloseType close(Hop hop) { + // close on second matrix multiply (after open) or unary aggregate + if( hop instanceof AggUnaryOp && HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)) ) + return CloseType.CLOSED_INVALID; + else if( (hop instanceof AggUnaryOp) + || (hop instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop) + && !HopRewriteUtils.isTransposeOperation(hop.getParent().get(0))) + || (HopRewriteUtils.isTransposeOperation(hop) && hop.getInput().get(0) instanceof AggBinaryOp + && !HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)) )) + return CloseType.CLOSED_VALID; else - { - _transposeInput = true; // we need to transpose V to be tall and skinny - uniqueMatrixInputs.put("V", h.getInput().get(1)); - } - rfindOuterProduct(_endHop, _endHop, uniqueMatrixInputs, h.getDim1(), h.getDim2(), new HashSet<Long>()); + return CloseType.OPEN; + } + + public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) + { + //recursively process required cplan output + HashSet<Hop> inHops = new HashSet<Hop>(); + HashMap<String,Hop> inHops2 = new HashMap<String, Hop>(); + HashMap<Long, CNode> tmp = new HashMap<Long, CNode>(); + hop.resetVisitStatus(); + rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals); + hop.resetVisitStatus(); - if(uniqueMatrixInputs.size() == 3 && _initialHop != null && _initialHop != _endHop ) //sanity check - { - //check if added matrices can be inferred from input matrices for example (X!=0) or abs(X) are not different from X - Hop commonChild = null; - if(! _adddedMatrices.isEmpty() ) { - //if addedMatrices does not have a common child with input X then do not compile - commonChild = TemplateUtils.commonChild(_adddedMatrices,uniqueMatrixInputs.get("X")); - if(commonChild == null ) // there are multiple matrices involved other than X - return false; - } - if(commonChild != null) { - _matrixInputs.add(commonChild); //add common child as the major input matrix - _adddedMatrices.add(uniqueMatrixInputs.get("X")); // put unique matrix as one of the additional matrices that is a chain of cell wise operations for the input matrix - } + //reorder inputs (ensure matrix is first input) + Hop X = inHops2.get("_X"); + Hop U = inHops2.get("_U"); + Hop V = inHops2.get("_V"); + LinkedList<Hop> sinHops = new LinkedList<Hop>(inHops); + sinHops.remove(V); sinHops.addFirst(V); + sinHops.remove(U); sinHops.addFirst(U); + sinHops.remove(X); sinHops.addFirst(X); + + //construct template node + ArrayList<CNode> inputs = new ArrayList<CNode>(); + for( Hop in : sinHops ) + if( in != null ) + inputs.add(tmp.get(in.getHopID())); + + CNode output = tmp.get(hop.getHopID()); + CNodeOuterProduct tpl = new CNodeOuterProduct(inputs, output); + tpl.setOutProdType(TemplateUtils.getOuterProductType(X, U, V, hop)); + tpl.setTransposeOutput(!HopRewriteUtils.isTransposeOperation(hop) + && tpl.getOutProdType()==OutProdType.LEFT_OUTER_PRODUCT); + + return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); + } + + private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) + { + //recursively process required childs + MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OuterProdTpl); + for( int i=0; i<hop.getInput().size(); i++ ) { + Hop c = hop.getInput().get(i); + if( me.isPlanRef(i) ) + rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals); else { - _matrixInputs.add(uniqueMatrixInputs.get("X")); //major matrix is the sparse driver + CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals); + tmp.put(c.getHopID(), cdata); + inHops.add(c); } - _matrixInputs.add(uniqueMatrixInputs.get("U")); + } + + //construct cnode for current hop + CNode out = null; + if(hop instanceof UnaryOp) + { + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + String primitiveOpName = ((UnaryOp)hop).getOp().toString(); + out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName)); + } + else if(hop instanceof BinaryOp) + { + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); + String primitiveOpName = ((BinaryOp)hop).getOp().toString(); - if(_transposeInput) { - ReorgOp transposeV = HopRewriteUtils.createTranspose(uniqueMatrixInputs.get("V")); - //ReorgOp transposeV = new ReorgOp("", uniqueMatrixInputs.get("V").getDataType(), uniqueMatrixInputs.get("V").getValueType(), ReOrgOp.TRANSPOSE, uniqueMatrixInputs.get("V")); - _matrixInputs.add(transposeV); + if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) ) { + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); } - else { - _matrixInputs.add(uniqueMatrixInputs.get("V")); + if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) ) { + cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); } + if( HopRewriteUtils.isEqualSize(hop.getInput().get(0), hop.getInput().get(1)) + && hop.getInput().get(0) instanceof DataOp ) + inHops2.put("_X", hop.getInput().get(0)); - //add also added matrices so that they can be interpreted as inputs - for(Hop addedMatrix : _adddedMatrices) - if(!_matrixInputs.contains(addedMatrix)) - _matrixInputs.add(addedMatrix); - - //add the children of _endHop ( this will handle the case for wdivmm right when I add the both t(V) and V as inputs - for (Hop hop: _endHop.getInput()) - _matrixInputs.add(hop); - - return true; + out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } - else - return false; - - } - private void rfindOuterProduct(Hop child, Hop h, HashMap<String,Hop> uniqueMatrixInputs, long outerProductDim1, long outerProductDim2, HashSet<Long> memo) - { - if(memo.contains(h.getHopID())) - return; - - if( ( h instanceof UnaryOp || h instanceof BinaryOp ) //unary operation or binary operation - && h.getDataType() == DataType.MATRIX // Output is a matrix - && h.getDim1() == outerProductDim1 && h.getDim2() == outerProductDim2 // output is the same size as the matrix - && TemplateUtils.isOperationSupported(h)) // operation is supported in codegen + else if(hop instanceof AggBinaryOp) { - if(h instanceof BinaryOp) + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); + + //handle tanspose in outer or final product + cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals); + cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals); + + //outerproduct U%*%t(V), see open + if( HopRewriteUtils.isOuterProductLikeMM(hop) ) { - - // find the other child rather than the one that called the parent - Hop otherChild = h.getInput().get(0) != child ? h.getInput().get(0) : h.getInput().get(1); - - //if scalar or vector then we fuse it similar to the way we fuse celltpl, - if(TemplateUtils.isVectorOrScalar(otherChild)) - { - _initialHop = h; - _outerProductType = OutProdType.CELLWISE_OUTER_PRODUCT; - - } - // other child is a matrix + //keep U and V for later reference + inHops2.put("_U", hop.getInput().get(0)); + if( HopRewriteUtils.isTransposeOperation(hop.getInput().get(1)) ) + inHops2.put("_V", hop.getInput().get(1).getInput().get(0)); else - { - //if the binary operation is sparse safe (mult, div) - if(sparseDrivers.contains(((BinaryOp)h).getOp()) ) - { - if(!uniqueMatrixInputs.containsKey("X")) - { - //extra sanity check - if(otherChild.getDim1() == outerProductDim1 && otherChild.getDim2() == outerProductDim2) { - uniqueMatrixInputs.put("X", otherChild); - _initialHop = h; - } - else { //matrix size does not match what is expected for X - return; - } - } - } - else { - _adddedMatrices.add(otherChild); - } - } + inHops2.put("_V", hop.getInput().get(1)); + + out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); } - } - - if( h instanceof AggBinaryOp && ((AggBinaryOp) h).isMatrixMultiply() && h != child) //make sure that the AggBinaryOp is not the same as the outerproduct that triggered this method - { - if(memo.contains(h.getInput().get(0).getHopID())) { // if current node is the parent for the left child then it is right matrix multiply - - if (h.getInput().get(1) == uniqueMatrixInputs.get("V") )//right operand is V - { - _initialHop = h; - _outerProductType = OutProdType.RIGHT_OUTER_PRODUCT; - return; - } - //right operand is t(V) - else if(h.getInput().get(1) instanceof ReorgOp && ((ReorgOp)h.getInput().get(1)).getOp() == ReOrgOp.TRANSPOSE && h.getInput().get(1).getInput().get(0) == uniqueMatrixInputs.get("V") ) - { - //replace V with T(V) - uniqueMatrixInputs.put("V", h.getInput().get(1)); - _transposeInput = false; //no need to transpose Input - _initialHop = h; - _outerProductType = OutProdType.RIGHT_OUTER_PRODUCT; - return; - } + //final left/right matrix mult, see close + else { + if( cdata1.getDataType().isScalar() ) + out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD); else - { - _initialHop = h.getInput().get(0); // set the child that was processed - return; - } - } - else {//left matrix multiply - - //left is T(U) - if (h.getInput().get(0) instanceof ReorgOp && ((ReorgOp)h.getInput().get(0)).getOp() == ReOrgOp.TRANSPOSE && h.getInput().get(0).getInput().get(0) == uniqueMatrixInputs.get("U") ) - { - _initialHop = h; - _outerProductType = OutProdType.LEFT_OUTER_PRODUCT; - //T(T(U) %*% ..) - for(Hop hParent : h.getParent()) - if(hParent instanceof ReorgOp && ((ReorgOp)hParent).getOp() == ReOrgOp.TRANSPOSE) { - _initialHop = hParent; // set the transpose hop - return; - } - _transposeOutput = true; - return; - } - else { - _initialHop = h.getInput().get(1); // set the child that was processed - return; - } + out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD); } } - - if( h instanceof AggUnaryOp && ((AggUnaryOp) h).getOp() == AggOp.SUM - && ((AggUnaryOp) h).getDirection() == Direction.RowCol) + else if( HopRewriteUtils.isTransposeOperation(hop) ) { - _initialHop = h; - _outerProductType = OutProdType.AGG_OUTER_PRODUCT; - return; + out = tmp.get(hop.getInput().get(0).getHopID()); } - - memo.add(h.getHopID()); - //process parents recursively - for( Hop parent : h.getParent()) - rfindOuterProduct(h, parent,uniqueMatrixInputs, outerProductDim1,outerProductDim2, memo); - } - - ////////////////Helper methods for finding boundaries - private OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out) - { - if (_outerProductType != null) - return _outerProductType; - - - //extra checks to infer type - if (out.getDataType() == DataType.SCALAR) // sum - { - _outerProductType = OutProdType.AGG_OUTER_PRODUCT; - } - else if( isDimsEqual(out,V) && out instanceof ReorgOp) // the second condition is added because sometimes V and U might be same dimensions if the dims of X are equal - { - _outerProductType = OutProdType.LEFT_OUTER_PRODUCT; - } - else if( isDimsEqual(out,U)) + else if( hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM + && ((AggUnaryOp)hop).getDirection() == Direction.RowCol ) { - _outerProductType = OutProdType.RIGHT_OUTER_PRODUCT; - } - else if ( isDimsEqual(out,X) ) - { - _outerProductType = OutProdType.CELLWISE_OUTER_PRODUCT; + out = tmp.get(hop.getInput().get(0).getHopID()); } - return _outerProductType; - } - - private static boolean isDimsEqual(Hop hop1, Hop hop2) - { - if(hop1.getDim1() == hop2.getDim1() && hop1.getDim2() == hop2.getDim2()) - return true; - return false; - } - - @Override - public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals) throws DMLException { - - //re-assign the dimensions of inputs to match the generated code dimensions - - //matrix X is a scalar in the generated code - _initialCnodes.add(new CNodeData(_matrixInputs.get(0), 1,1,DataType.SCALAR)); - - //matrix V - _initialCnodes.add(new CNodeData(_matrixInputs.get(1), 1,(int)_matrixInputs.get(1).getDim2(), DataType.MATRIX)); - - //matrix V - _initialCnodes.add(new CNodeData(_matrixInputs.get(2), 1,(int)_matrixInputs.get(2).getDim2(),DataType.MATRIX)); - - rConstructOuterProdCplan(_initialHop, _initialHop, new HashSet<Long>(), compileLiterals); - return _cpplans; - } - - private void rConstructOuterProdCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals) throws DMLException - { - if( memo.contains(hop.getHopID()) ) - return; - //process childs recursively - for( Hop c : hop.getInput() ) - rConstructOuterProdCplan(root, c, memo, compileLiterals); - - //organize the main inputs - Hop X, U, V; - X = _matrixInputs.get(0); - U = _matrixInputs.get(1); - V = _matrixInputs.get(2); - if(hop==_endHop) - _endHopReached = true; - - // first hop to enter here should be _endHop - if(TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans) && _endHopReached) // if direct children are DataGenOps, literals, or already in the cpplans then we are ready to generate code - { - CNodeOuterProduct outerProdTmpl = null; - - //Fetch operands - CNode out = null; - ArrayList<CNode> addedCNodes = new ArrayList<CNode>(); - ArrayList<Hop> addedHops = new ArrayList<Hop>(); - ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals); - - //if operands are scalar or independent from X - boolean independentOperands = hop != root && (hop.getDataType() == DataType.SCALAR || TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[]{_matrixInputs.get(0).getName(),_matrixInputs.get(1).getName(),_matrixInputs.get(2).getName()})); - if(!independentOperands) - { - if(hop instanceof UnaryOp) - { - CNode cdata1 = cnodeData.get(0); - - //Primitive Operation has the same name as Hop Type OpOp1 - String primitiveOpName = ((UnaryOp)hop).getOp().toString(); - out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName)); - } - else if(hop instanceof BinaryOp) - { - CNode cdata1 = cnodeData.get(0); - CNode cdata2 = cnodeData.get(1); - - //Primitive Operation has the same name as Hop Type OpOp2 - String primitiveOpName = ((BinaryOp)hop).getOp().toString(); - - if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) ) - { - //second argument is always the vector - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - //out = new CNodeBinary(tmp, cdata2, BinType.valueOf(primitiveOpName)); - } - //cdata1 is a matrix - else if ( (cdata1.getNumRows() > 1 && cdata1.getNumCols() > 1) ) - { - CellTpl cellTpl = new CellTpl(); - cdata1 = cellTpl.fuseCellWise(hop.getInput().get(0), _matrixInputs.get(0), compileLiterals); // second argument is always matrix X - if (cdata1 == null) - return; - } - //cdata2 is vector - //else if( cdata2 instanceof CNodeData && (((CNodeData)cdata2).getNumRows() > 1 && ((CNodeData)cdata2).getNumCols() == 1) || ( ((CNodeData)cdata2).getNumRows() == 1 && ((CNodeData)cdata2).getNumCols() > 1 )) - if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) ) - { - cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); - //out = new CNodeBinary(cdata1, tmp, BinType.valueOf(primitiveOpName)); - } - //cdata2 is a matrix - else if ( (cdata2.getNumRows() > 1 && cdata2.getNumCols() > 1) ) - { - CellTpl cellTpl = new CellTpl(); - cdata2 = cellTpl.fuseCellWise(hop.getInput().get(1), _matrixInputs.get(0), compileLiterals); // second argument is always matrix X - if (cdata2 == null) - return; - } - out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); - } - else if(hop instanceof AggBinaryOp) - { - CNode cdata1 = cnodeData.get(0); - CNode cdata2 = cnodeData.get(1); // remember that we already fetched what is under transpose - - //outerproduct U%*%t(V) then we should have passsed in V as the input - if(hop.getInput().get(0) == U && hop.getInput().get(1) instanceof ReorgOp && hop.getInput().get(1).getInput().get(0) == V) - { - //re-assign cdata2 to read V instead of t(V) - cdata2 = _initialCnodes.get(2); // the initialCNodes holds V - out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); - } - - //outerproduct U%*%V then we should have passsed in trnasposeV as the input - else if(hop.getInput().get(0) == U && V instanceof ReorgOp && V.getInput().get(0)== hop.getInput().get(1)) - { - //re-assign cdata2 to read t(V) instead of V - cdata2 = _initialCnodes.get(2); // the initialCNodes holds transpose of V - out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); - } - //outerproduct U%*%V but not right wdivmm so we did not pass T(V) - else if(hop.getInput().get(0) == U && hop.getInput().get(1) == V ) - { - //re-assign cdata2 to read t(V) instead of V - cdata2 = _initialCnodes.get(2); // the initialCNodes holds transpose of V - out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); - } - - //left outerproduct (i.e., left operand is T(U) ) - else if(hop.getInput().get(0) instanceof ReorgOp && hop.getInput().get(0).getInput().get(0) == U) - { - //scalar is cdata2 - out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD); - } - - //right outerproduct (i.e., right operand is V ) - else if(hop.getInput().get(1) != U && hop.getInput().get(1) == V) - { - cdata2 = _initialCnodes.get(2); - out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD); - } - - //right outerproduct (i.e., right operand is t(V) ) - else if(hop.getInput().get(1) instanceof ReorgOp && hop.getInput().get(1).getInput().get(0) == V) - { - cdata2 = _initialCnodes.get(2); - out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD); - } - } - else if ( hop instanceof ReorgOp && ((ReorgOp)hop).getOp() == ReOrgOp.TRANSPOSE && root == hop) // if transpose wire the oinput in T( T(U ...) - { - out = cnodeData.get(0); - } - else if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM && root == hop - && ((AggUnaryOp)hop).getDirection() == Direction.RowCol ) - { - out = cnodeData.get(0); - } - } - // wire output to the template - if(out != null || independentOperands) - { - if(_cpplans.isEmpty()) - { - //first initialization has to have the first variable as input - ArrayList<CNode> initialInputs = new ArrayList<CNode>(); - - if(independentOperands) // pass the hop itself as an input instead of its children - { - CNode c = new CNodeData(hop); - initialInputs.addAll(_initialCnodes); - initialInputs.add(c); - outerProdTmpl = new CNodeOuterProduct(initialInputs, c); - outerProdTmpl.setOutProdType(getOuterProductType(X, U, V, root)); - outerProdTmpl.setTransposeOutput(_transposeOutput); - _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {X,U,V,hop} ,outerProdTmpl)); - } - else - { - initialInputs.addAll(_initialCnodes); - initialInputs.addAll(cnodeData); - outerProdTmpl = new CNodeOuterProduct(initialInputs, out); - outerProdTmpl.setOutProdType(getOuterProductType(X, U, V, root)); - outerProdTmpl.setTransposeOutput(_transposeOutput); - - Hop[] hopArray = new Hop[addedHops.size()+3]; - hopArray[0] = X; - hopArray[1] = U; - hopArray[2] = V; - - System.arraycopy( addedHops.toArray(), 0, hopArray, 3, addedHops.size()); - - _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,outerProdTmpl)); - } - } - else - { - if(independentOperands) - { - CNode c = new CNodeData(hop); - //clear Operands - addedCNodes.clear(); - addedHops.clear(); - - //added the current hop as the input - addedCNodes.add(c); - addedHops.add(hop); - out = c; - } - - //wire the output to existing or new template - TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops); - } - } - memo.add(hop.getHopID()); - } + tmp.put(hop.getHopID(), out); } }
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java index 367cca3..9fa5efd 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java @@ -20,31 +20,27 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashMap; +import java.util.LinkedList; -import org.apache.sysml.api.DMLException; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; -import org.apache.sysml.hops.DataGenOp; -import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.LiteralOp; -import org.apache.sysml.hops.ReorgOp; -import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysml.hops.codegen.cplan.CNodeData; -import org.apache.sysml.hops.codegen.cplan.CNodeRowAggVector; +import org.apache.sysml.hops.codegen.cplan.CNodeRowAgg; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; -import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.matrix.data.Pair; @@ -55,266 +51,155 @@ public class RowAggTpl extends BaseTpl { } @Override - public boolean openTpl(Hop hop) { - if ( (hop instanceof AggBinaryOp || hop instanceof AggUnaryOp) // An aggregate operation - && ( (hop.getDim1()==1 && hop.getDim2()!=1) || (hop.getDim1()!=1 && hop.getDim2()==1) ) )// the output is a vector - return true; - return false; + public boolean open(Hop hop) { + //any unary or binary aggregate operation with a vector output, but exclude binary aggregate + //with transposed input to avoid counter-productive fusion + return ( ((hop instanceof AggBinaryOp && hop.getInput().get(1).getDim1()>1 + && !HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) + || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM )) + && ( (hop.getDim1()==1 && hop.getDim2()!=1) || (hop.getDim1()!=1 && hop.getDim2()==1) ) ); } @Override - public boolean findTplBoundaries(Hop initialHop, CplanRegister cplanRegister) { - _initialHop = initialHop; - if(initialHop instanceof AggBinaryOp) { - // for simplicity we assume that the first operand should be t(X) however, it could be later on W.T(X) - if(initialHop.getInput().get(0) instanceof ReorgOp && ((ReorgOp)initialHop.getInput().get(0)).getOp()== ReOrgOp.TRANSPOSE ) - _matrixInputs.add(initialHop.getInput().get(0).getInput().get(0)); //add what is under the transpose - else - return false; - } - rFindRowAggPattern(initialHop, new HashSet<Long>()); - - if(cplanRegister.containsHop(TemplateType.RowAggTpl, initialHop.getHopID())) - return false; - - return (_endHop != null); + public boolean fuse(Hop hop, Hop input) { + return !isClosed() && + ( (hop instanceof BinaryOp && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) + || HopRewriteUtils.isBinaryMatrixScalarOperation(hop))) + || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()==Direction.Col) + || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); } - - - private void rFindRowAggPattern(Hop h, HashSet<Long> memo) - { - if(memo.contains(h.getHopID()) || h.getDataType() == DataType.SCALAR - || h instanceof DataOp || h instanceof DataGenOp || h instanceof LiteralOp) { - return; - } + + @Override + public boolean merge(Hop hop, Hop input) { + //merge rowagg tpl with cell tpl if input is a vector + return !isClosed() && + (hop instanceof BinaryOp && input.getDim2()==1 ); //matrix-scalar/vector-vector ops ) + } + + @Override + public CloseType close(Hop hop) { + //close on column aggregate (e.g., colSums, t(X)%*%y) + if( hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()==Direction.Col + || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) ) + return CloseType.CLOSED_VALID; + else + return CloseType.OPEN; + } + + @Override + public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) { + //recursively process required cplan output + HashSet<Hop> inHops = new HashSet<Hop>(); + HashMap<Long, CNode> tmp = new HashMap<Long, CNode>(); + hop.resetVisitStatus(); + rConstructCplan(hop, memo, tmp, inHops, compileLiterals); + hop.resetVisitStatus(); - boolean continueTraversing = false; - if (h instanceof AggBinaryOp) - { - if(h != _initialHop) { - //T(X) % ..... X %*% v ,check that X is the same as what we saw previously under transpose - if( h.getInput().get(0).equals(_matrixInputs.get(0)) && TemplateUtils.isVector(h.getInput().get(1)) ) { - _endHop = h; - } + //reorder inputs (ensure matrix is first input) + LinkedList<Hop> sinHops = new LinkedList<Hop>(inHops); + for( Hop h : inHops ) + if( h.getDataType().isMatrix() && !TemplateUtils.isVector(h) ) { + sinHops.remove(h); + sinHops.addFirst(h); } + + //construct template node + ArrayList<CNode> inputs = new ArrayList<CNode>(); + for( Hop in : sinHops ) + inputs.add(tmp.get(in.getHopID())); + CNode output = tmp.get(hop.getHopID()); + CNodeRowAgg tpl = new CNodeRowAgg(inputs, output); + + // return cplan instance + return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); + } + + private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) + { + //recursively process required childs + MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.RowAggTpl); + for( int i=0; i<hop.getInput().size(); i++ ) { + Hop c = hop.getInput().get(i); + if( me.isPlanRef(i) ) + rConstructCplan(c, memo, tmp, inHops, compileLiterals); else { - continueTraversing = true; + CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals); + tmp.put(c.getHopID(), cdata); + inHops.add(c); } } - // if initial hop is colSums continue - else if(h instanceof AggUnaryOp && (((AggUnaryOp)_initialHop).getDirection() == Direction.Col && ((AggUnaryOp)_initialHop).getOp() == AggOp.SUM ) && h == _initialHop) + + //construct cnode for current hop + CNode out = null; + if(hop instanceof AggUnaryOp) { - continueTraversing=true; + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + if( ((AggUnaryOp)hop).getDirection() == Direction.Row && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) { + if(hop.getInput().get(0).getDim2()==1) + out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R); + else + out = new CNodeUnary(cdata1, UnaryType.ROW_SUMS); + } + else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) { + //vector div add without temporary copy + if(cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType()==BinType.VECT_DIV_SCALAR) + out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), BinType.VECT_DIV_ADD); + else + out = cdata1; + } } - //rowSums(X) - else if(h instanceof AggUnaryOp && ((AggUnaryOp)h).getDirection() == Direction.Row && ((AggUnaryOp)h).getOp() == AggOp.SUM ) + else if(hop instanceof AggBinaryOp) { - // check if root pattern is colsums - if((((AggUnaryOp)_initialHop).getDirection() == Direction.Col && ((AggUnaryOp)_initialHop).getOp() == AggOp.SUM )) + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); + + if( HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)) ) { + //correct input under transpose + cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals); + inHops.remove(hop.getInput().get(0)); + inHops.add(hop.getInput().get(0).getInput().get(0)); - //TODO Now the pattern is limited to finding rowSums - _matrixInputs.add(h.getInput().get(0)); - _endHop = h; + out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD); + } + else + { + if(hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1) + out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), + (cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT); + else + out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); } } - // unary operation || binary operation with first input as a matrix || binary operation with second input as a matrix - else if( ( h instanceof UnaryOp || (h instanceof BinaryOp && h.getInput().get(0).getDataType() == DataType.MATRIX && TemplateUtils.isVectorOrScalar(h.getInput().get(1))) || (h instanceof BinaryOp && TemplateUtils.isVectorOrScalar(h.getInput().get(0)) && h.getInput().get(1).getDataType() == DataType.MATRIX) ) //unary operation or binary operaiton with one matrix and a scalar - && h.getDataType() == DataType.MATRIX // Output is a matrix - && TemplateUtils.isOperationSupported(h) ) //Operation is supported in codegen - { - continueTraversing = true; - } - - //check if we should continue traversing - if(!continueTraversing) - { - return; // stop traversing if conditions does not apply - } - else + else if(hop instanceof BinaryOp) { - //process childs recursively - for( Hop in : h.getInput() ) - rFindRowAggPattern(in,memo); - } - memo.add(h.getHopID()); - } - - @Override - public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals) - throws DMLException { - - //re-assign the dimensions of inputs to match the generated code dimensions - _initialCnodes.add(new CNodeData(_matrixInputs.get(0))); - - rConstructRowAggCplan(_initialHop,_initialHop,new HashSet<Long>(), compileLiterals); - return _cpplans; - } - - private void rConstructRowAggCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals) throws DMLException - { - if( memo.contains(hop.getHopID()) ) - return; - //process childs recursively - for( Hop c : hop.getInput() ) - rConstructRowAggCplan(root, c, memo, compileLiterals); - if(hop == _endHop) - _endHopReached = true; - - // first hop to enter here should be _endHop - if(TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans) && _endHopReached) // if direct children are DataGenOps, literals, or already in the cpplans then we are ready to generate code - { - CNodeRowAggVector rowTmpl = null; - - //Fetch operands - CNode out = null; - ArrayList<CNode> addedCNodes = new ArrayList<CNode>(); - ArrayList<Hop> addedHops = new ArrayList<Hop>(); - ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals); + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); - //if operands are scalar or independent from X - boolean independentOperands = hop.getDataType() == DataType.SCALAR - || TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[] {_matrixInputs.get(0).getName()}); - - if(!independentOperands) + // if one input is a matrix then we need to do vector by scalar operations + if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 ) { - - if(hop instanceof AggUnaryOp) - { - CNode cdata1 = cnodeData.get(0); - //set the out cnode based on the operation - if( ((AggUnaryOp)hop).getDirection() == Direction.Row && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) //RowSums - { - if(hop.getInput().get(0).getDim2()==1) - out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R); - else - out = new CNodeUnary(cdata1, UnaryType.ROW_SUMS); - } - // if colsums is the root hop, wire the input to the out because colsums it is done automatically by the template - else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM && root == hop) - { - //vector div add without temporary copy - if(cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType()==BinType.VECT_DIV_SCALAR) - out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), BinType.VECT_DIV_ADD); - else - out = cdata1; - } - } - else if(hop instanceof AggBinaryOp) - { - //Fetch operands specific to the operation - CNode cdata1 = cnodeData.get(0); - CNode cdata2 = cnodeData.get(1); - - //choose the operation based on the transpose - if( hop.getInput().get(0) instanceof ReorgOp && ((ReorgOp)hop.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE ) - { - //fetch the data inside the transpose - //cdata1 = new CNodeData(hop.getInput().get(0).getInput().get(0).getName(), (int)hop.getInput().get(0).getInput().get(0).getDim1(), (int)hop.getInput().get(0).getInput().get(0).getDim2()); - out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD); - } - else - { - if(hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1) - out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), - (cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT); - else - out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); - } - } - else if(hop instanceof BinaryOp) - { - CNode cdata1 = cnodeData.get(0); - CNode cdata2 = cnodeData.get(1); - - // if one input is a matrix then we need to do vector by scalar operations - if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 ) - { - if (((BinaryOp)hop).getOp()== OpOp2.DIV) - //CNode generatedScalar = new CNodeData("1", 0, 0); // generate literal in order to rewrite the div to x * 1/y - //CNode outScalar = new CNodeBinary(generatedScalar, cdata2, BinType.SCALAR_DIVIDE); - //out = new CNodeBinary(outScalar, cdata1, BinType.VECT_MULT_ADD); - out = new CNodeBinary(cdata1, cdata2, BinType.VECT_DIV_SCALAR); - - } - else //one input is a vector/scalar other is a scalar - { - //Primitive Operation has the same name as Hop Type OpOp2 - String primitiveOpName = ((BinaryOp)hop).getOp().toString(); - - if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) ) - { - //second argument is always the vector - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - } - //cdata2 is vector - //else if( cdata2 instanceof CNodeData && (((CNodeData)cdata2).getNumRows() > 1 && ((CNodeData)cdata2).getNumCols() == 1) || ( ((CNodeData)cdata2).getNumRows() == 1 && ((CNodeData)cdata2).getNumCols() > 1 )) - if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) ) - { - cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); - //out = new CNodeBinary(cdata1, tmp, BinType.valueOf(primitiveOpName)); - } - out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); - } - - } - - if( out.getDataType().isMatrix() ) { - out.setNumRows(hop.getDim1()); - out.setNumCols(hop.getDim2()); - } + if (((BinaryOp)hop).getOp()== OpOp2.DIV) + out = new CNodeBinary(cdata1, cdata2, BinType.VECT_DIV_SCALAR); } - // wire output to the template - if(out != null || independentOperands) + else //one input is a vector/scalar other is a scalar { - if(_cpplans.isEmpty()) - { - //first initialization has to have the first variable as input - ArrayList<CNode> initialInputs = new ArrayList<CNode>(); - - if(independentOperands) // pass the hop itself as an input instead of its children - { - CNode c = new CNodeData(hop); - initialInputs.addAll(_initialCnodes); - initialInputs.add(c); - rowTmpl = new CNodeRowAggVector(initialInputs, c); - _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {_matrixInputs.get(0),hop} ,rowTmpl)); - } - else - { - initialInputs.addAll(_initialCnodes); - initialInputs.addAll(cnodeData); - rowTmpl = new CNodeRowAggVector(initialInputs, out); - - //Hop[] hopArray = new Hop[hop.getInput().size()+1]; - Hop[] hopArray = new Hop[addedHops.size()+1]; - hopArray[0] = _matrixInputs.get(0); - - //System.arraycopy( hop.getInput().toArray(), 0, hopArray, 1, hop.getInput().size()); - System.arraycopy( addedHops.toArray(), 0, hopArray, 1, addedHops.size()); - - _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,rowTmpl)); - } + String primitiveOpName = ((BinaryOp)hop).getOp().toString(); + if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) ) { + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); } - else - { - if(independentOperands) - { - CNode c = new CNodeData(hop); - //clear Operands - addedCNodes.clear(); - addedHops.clear(); - - //added the current hop as the input - addedCNodes.add(c); - addedHops.add(hop); - out = c; - } - //wire the output to existing or new template - TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops); + if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) ) { + cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); } + out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } - memo.add(hop.getHopID()); - } + } + + if( out.getDataType().isMatrix() ) { + out.setNumRows(hop.getDim1()); + out.setNumCols(hop.getDim2()); + } + + tmp.put(hop.getHopID(), out); } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 6934e02..80216fd 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -24,178 +24,32 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashSet; -import java.util.Map.Entry; +import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; -import org.apache.sysml.hops.DataGenOp; -import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.LiteralOp; -import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; -import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; -import org.apache.sysml.hops.codegen.cplan.CNodeCell; import org.apache.sysml.hops.codegen.cplan.CNodeData; -import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; -import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; -import org.apache.sysml.runtime.matrix.data.Pair; +import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; import org.apache.sysml.runtime.util.UtilFunctions; public class TemplateUtils { - public static boolean inputsAreGenerated(Hop parent, ArrayList<Hop> inputs, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans) - { - if( parent instanceof DataOp || parent instanceof DataGenOp || parent instanceof LiteralOp || inputs.contains(parent) ) - return false; - - for(Hop hop : parent.getInput() ) - if(!inputs.contains(hop) && !(hop instanceof DataOp) && !(hop instanceof DataGenOp) && !(hop.getDataType()==DataType.SCALAR) && !isVector(hop) && !(cpplans.containsKey(hop.getHopID())) && !( hop instanceof ReorgOp && ((ReorgOp)hop).getOp() == ReOrgOp.TRANSPOSE && inputsAreGenerated(hop,inputs, cpplans) )) - return false; - return true; - } - - public static ArrayList<CNode> fetchOperands(Hop hop, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans, ArrayList<CNode> addedCNodes, ArrayList<Hop> addedHops, ArrayList<CNodeData> initialCNodes, boolean compileLiterals) - { - ArrayList<CNode> cnodeData = new ArrayList<CNode>(); - for (Hop h: hop.getInput()) - { - CNode cdata = null; - - //CNodeData already in template inputs - for(CNodeData c : initialCNodes) { - if( c.getHopID() == h.getHopID() ) { - cdata = c; - break; - } - } - - if(cdata != null) - { - cnodeData.add(cdata); - continue; - } - //hop already in the cplan - else if(cpplans.containsKey(h.getHopID())) - { - cdata = cpplans.get(h.getHopID()).getValue().getOutput(); - } - else if(h instanceof ReorgOp && ((ReorgOp)h).getOp()==ReOrgOp.TRANSPOSE ) - { - //fetch what is under the transpose - Hop in = h.getInput().get(0); - cdata = new CNodeData(in); - if(in instanceof DataOp || in instanceof DataGenOp ) { - addedCNodes.add(cdata); - addedHops.add(in); - } - } - else - { - //note: only compile literals if forced or integer literals (likely constants) - //to increase reuse potential on literal replacement during recompilation - cdata = new CNodeData(h); - cdata.setLiteral(h instanceof LiteralOp && (compileLiterals - || UtilFunctions.isIntegerNumber(((LiteralOp)h).getStringValue()))); - if( !cdata.isLiteral() ) { - addedCNodes.add(cdata); - addedHops.add(h); - } - } - - cnodeData.add(cdata); - } - return cnodeData; - } - - public static void setOutputToExistingTemplate(Hop hop, CNode out, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans, ArrayList<CNode> addedCNodes, ArrayList<Hop> addedHops) - { - //get the toplevel rowTemp - Entry<Long, Pair<Hop[],CNodeTpl>> cplan = null; - Iterator<Entry<Long, Pair<Hop[],CNodeTpl>>> iterator = cpplans.entrySet().iterator(); - while (iterator.hasNext()) - cplan = iterator.next(); - - CNodeTpl tmpl = cplan.getValue().getValue().clone(); - tmpl.setDataType(hop.getDataType()); - - if(tmpl instanceof CNodeOuterProduct) { - ((CNodeOuterProduct) tmpl).setOutProdType( ((CNodeOuterProduct)cplan.getValue().getValue()).getOutProdType()); - ((CNodeOuterProduct) tmpl).setTransposeOutput(((CNodeOuterProduct)cplan.getValue().getValue()).isTransposeOutput() ); - } - else if( tmpl instanceof CNodeCell ) { - ((CNodeCell)tmpl).setCellType(getCellType(hop)); - ((CNodeCell)tmpl).setMultipleConsumers(hop.getParent().size()>1); - } - - //add extra inputs - for(CNode c : addedCNodes) - tmpl.addInput(c); - - //modify addedHops if they exist - - Hop[] currentInputHops = cplan.getValue().getKey(); - for (Hop h : currentInputHops) - if (addedHops.contains(h)) - addedHops.remove(h); - - Hop[] extendedHopInputs = new Hop[cplan.getValue().getKey().length + addedHops.size()]; - System.arraycopy(cplan.getValue().getKey(), 0, extendedHopInputs, 0, cplan.getValue().getKey().length); - for(int j=addedHops.size(); j > 0; j--) - extendedHopInputs[extendedHopInputs.length-j] = addedHops.get(addedHops.size() - j); //append the added hops to the end of the array - - //set the template output and add it to the cpplans - Pair<Hop[],CNodeTpl> pair = new Pair<Hop[],CNodeTpl>(extendedHopInputs,tmpl); - pair.getValue().setOutput(out); - cpplans.put(hop.getHopID(), pair); - - } - - public static boolean isOperandsIndependent(ArrayList<CNode> cnodeData, ArrayList<Hop> addedHops, String[] varNames) - { - for(CNode c : cnodeData) { - // it is some variable inside the cplan // TODO needs to be modified because sometimes the varname is not null but the variable is in the cplan - if(c.getVarname() == null) - return false; - //if one of the operands is is any of the varnames // if one of the operands is T(X) this condition will apply as well because during fetch operands we fetch what is inside transpose - for(String varName : varNames) - if(c.getVarname().equals(varName)) - return false; - } - return true; - } - - public static Entry<Long, Pair<Hop[],CNodeTpl>> getTopLevelCpplan(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) - { - Entry<Long, Pair<Hop[],CNodeTpl>> ret = null; - - //get last entry (most fused operators) or special handling - boolean hasExp = false; - for( Entry<Long, Pair<Hop[],CNodeTpl>> e : cplans.entrySet() ) - { - ret = e; //keep last seen entry - - //special handling overlapping fused operators with exp - hasExp |= (ret.getValue().getValue().getOutput() instanceof CNodeUnary - && ((CNodeUnary)ret.getValue().getValue().getOutput()).getType()==UnaryType.EXP); - - if( hasExp && ret.getValue().getValue() instanceof CNodeCell - && ((CNodeCell)ret.getValue().getValue()).hasMultipleConsumers() ) - break; - } - - return ret; - } + public static final BaseTpl[] TEMPLATES = new BaseTpl[]{new RowAggTpl(), new CellTpl(), new OuterProductTpl()}; public static boolean isVector(Hop hop) { return (hop.getDataType() == DataType.MATRIX @@ -230,6 +84,16 @@ public class TemplateUtils && left.getDataType().isMatrix() && right.getDataType().isMatrix() && left.getDim1() > right.getDim1(); } + + public static boolean isBinaryMatrixColVector(Hop hop) { + if( !(hop instanceof BinaryOp) ) + return false; + Hop left = hop.getInput().get(0); + Hop right = hop.getInput().get(1); + return left.dimsKnown() && right.dimsKnown() + && left.getDataType().isMatrix() && right.getDataType().isMatrix() + && left.getDim2() > right.getDim2(); + } public static boolean isOperationSupported(Hop h) { if(h instanceof UnaryOp) @@ -307,16 +171,81 @@ public class TemplateUtils ret[pos++] = c; return ret; } + + public static BaseTpl createTemplate(TemplateType type) { + return createTemplate(type, false); + } - private static CellType getCellType(Hop hop) { + public static BaseTpl createTemplate(TemplateType type, boolean closed) { + BaseTpl tpl = null; + switch( type ) { + case CellTpl: tpl = new CellTpl(); break; + case RowAggTpl: tpl = new RowAggTpl(); break; + case OuterProdTpl: tpl = new OuterProductTpl(); break; + } + tpl._closed = closed; + return tpl; + } + + public static CellType getCellType(Hop hop) { return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM) ? ((((AggUnaryOp) hop).getDirection() == Direction.RowCol) ? CellType.FULL_AGG : CellType.ROW_AGG) : CellType.NO_AGG; } + public static OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out) { + if( out.getDataType() == DataType.SCALAR ) + return OutProdType.AGG_OUTER_PRODUCT; + else if( (out instanceof AggBinaryOp && (out.getInput().get(0) == U + || HopRewriteUtils.isTransposeOperation(out.getInput().get(0)) + && out.getInput().get(0).getInput().get(0) == U)) + || HopRewriteUtils.isTransposeOperation(out) ) + return OutProdType.LEFT_OUTER_PRODUCT; + else if( out instanceof AggBinaryOp && (out.getInput().get(1) == V + || HopRewriteUtils.isTransposeOperation(out.getInput().get(1)) + && out.getInput().get(1).getInput().get(0) == V ) ) + return OutProdType.RIGHT_OUTER_PRODUCT; + else if( out instanceof BinaryOp && HopRewriteUtils.isEqualSize(out.getInput().get(0), out.getInput().get(1)) ) + return OutProdType.CELLWISE_OUTER_PRODUCT; + + //should never come here + throw new RuntimeException("Undefined outer product type"); + } + public static boolean isLookup(CNode node) { return (node instanceof CNodeUnary && (((CNodeUnary)node).getType()==UnaryType.LOOKUP_R || ((CNodeUnary)node).getType()==UnaryType.LOOKUP_RC)); } + + public static CNodeData createCNodeData(Hop hop, boolean compileLiterals) { + CNodeData cdata = new CNodeData(hop); + cdata.setLiteral(hop instanceof LiteralOp && (compileLiterals + || UtilFunctions.isIntegerNumber(((LiteralOp)hop).getStringValue()))); + return cdata; + } + + public static CNode skipTranspose(CNode cdataOrig, Hop hop, HashMap<Long, CNode> tmp, boolean compileLiterals) { + if( HopRewriteUtils.isTransposeOperation(hop) ) { + CNode cdata = tmp.get(hop.getInput().get(0).getHopID()); + if( cdata == null ) { //never accessed + cdata = TemplateUtils.createCNodeData(hop.getInput().get(0), compileLiterals); + tmp.put(hop.getInput().get(0).getHopID(), cdata); + } + tmp.put(hop.getHopID(), cdata); + return cdata; + } + else { + return cdataOrig; + } + } + + public static boolean hasTransposeParentUnderOuterProduct(Hop hop) { + for( Hop p : hop.getParent() ) + if( HopRewriteUtils.isTransposeOperation(p) ) + for( Hop p2 : p.getParent() ) + if( HopRewriteUtils.isOuterProductLikeMM(p2) ) + return true; + return false; + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index d6ad037..a384ac1 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -682,7 +682,8 @@ public class HopRewriteUtils } public static boolean isOuterProductLikeMM( Hop hop ) { - return isMatrixMultiply(hop) + return isMatrixMultiply(hop) && hop.dimsKnown() + && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() && hop.getInput().get(0).getDim1() > hop.getInput().get(0).getDim2() && hop.getInput().get(1).getDim1() < hop.getInput().get(1).getDim2(); } @@ -759,6 +760,19 @@ public class HopRewriteUtils ||(hop.getInput().get(1).getDataType().isMatrix() && hop.getInput().get(0).getDataType().isScalar())); } + public static boolean isBinaryMatrixMatrixOperation(Hop hop) { + return hop instanceof BinaryOp + && hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix() + && hop.getInput().get(0).dimsKnown() && hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 + && hop.getInput().get(1).dimsKnown() && hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1; + } + + public static boolean isBinaryMatrixColVectorOperation(Hop hop) { + return hop instanceof BinaryOp + && hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix() + && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() && hop.getInput().get(1).getDim2() == 1; + } + public static boolean isUnary(Hop hop, OpOp1 type) { return hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==type; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java index 005b3ca..2ff8772 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java @@ -138,7 +138,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl //input preparation double[][] b = prepInputMatrices(inputs); double[] scalars = prepInputScalars(scalarObjects); - + //core sequential execute final int m = inputs.get(0).getNumRows(); final int n = inputs.get(0).getNumColumns(); @@ -197,12 +197,12 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl //as the output might change with different row indices for( int i=rl; i<ru; i++ ) for( int j=0; j<n; j++ ) - kplus.execute2(kbuff, genexecDense( 0, b, scalars, n, m, i, j )); + kplus.execute2(kbuff, genexec( 0, b, scalars, n, m, i, j )); } else { //general case for( int i=rl, ix=rl*n; i<ru; i++ ) for( int j=0; j<n; j++, ix++ ) - kplus.execute2(kbuff, genexecDense( a[ix], b, scalars, n, m, i, j )); + kplus.execute2(kbuff, genexec( a[ix], b, scalars, n, m, i, j )); } return kbuff._sum; @@ -219,14 +219,14 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl //as the output might change with different row indices for( int i=rl, ix=rl*n; i<ru; i++ ) for( int j=0; j<n; j++, ix++ ) { - c[ix] = genexecDense( 0, b, scalars, n, m, i, j ); + c[ix] = genexec( 0, b, scalars, n, m, i, j ); lnnz += (c[ix]!=0) ? 1 : 0; } } else { //general case for( int i=rl, ix=rl*n; i<ru; i++ ) for( int j=0; j<n; j++, ix++ ) { - c[ix] = genexecDense( a[ix], b, scalars, n, m, i, j); + c[ix] = genexec( a[ix], b, scalars, n, m, i, j); lnnz += (c[ix]!=0) ? 1 : 0; } } @@ -242,7 +242,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl for( int i=rl; i<ru; i++ ) { kbuff.set(0, 0); for( int j=0; j<n; j++ ) - kplus.execute2(kbuff, genexecDense( 0, b, scalars, n, m, i, j )); + kplus.execute2(kbuff, genexec( 0, b, scalars, n, m, i, j )); c[i] = kbuff._sum; lnnz += (c[i]!=0) ? 1 : 0; } @@ -251,7 +251,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl for( int i=rl, ix=rl*n; i<ru; i++ ) { kbuff.set(0, 0); for( int j=0; j<n; j++, ix++ ) - kplus.execute2(kbuff, genexecDense( a[ix], b, scalars, n, m, i, j )); + kplus.execute2(kbuff, genexec( a[ix], b, scalars, n, m, i, j )); c[i] = kbuff._sum; lnnz += (c[i]!=0) ? 1 : 0; } @@ -267,7 +267,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl KahanPlus kplus = KahanPlus.getKahanPlusFnObject(); //TODO rework sparse safe test - double val = genexecDense( 0, b, scalars, n, m, 0, 0 ); + double val = genexec( 0, b, scalars, n, m, 0, 0 ); if(val == 0 && b.length==0) // sparse safe { @@ -278,7 +278,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl int alen = sblock.size(i); double[] avals = sblock.values(i); for( int j=apos; j<apos+alen; j++ ) { - kplus.execute2( kbuff, genexecDense(avals[j], b, scalars, n, m, i, j)); + kplus.execute2( kbuff, genexec(avals[j], b, scalars, n, m, i, j)); } } } @@ -288,7 +288,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl for(int i=rl; i<ru; i++) for(int j=0; j<n; j++) { double valij = (sblock != null) ? sblock.get(i, j) : 0; - kplus.execute2( kbuff, genexecDense(valij, b, scalars, n, m, i, j)); + kplus.execute2( kbuff, genexec(valij, b, scalars, n, m, i, j)); } } @@ -298,7 +298,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl private long executeSparse(SparseBlock sblock, double[][] b, double[] scalars, double[] c, int n, int m, int rl, int ru) { //TODO rework sparse safe test - double val0 = genexecDense( 0, b, scalars, n, m, 0, 0 ); + double val0 = genexec( 0, b, scalars, n, m, 0, 0 ); long lnnz = 0; if( _type == CellType.NO_AGG ) @@ -312,7 +312,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl int alen = sblock.size(i); double[] avals = sblock.values(i); for( int j=apos; j<apos+alen; j++ ) { - double val = genexecDense(avals[j], b, scalars, n, m, i, j); + double val = genexec(avals[j], b, scalars, n, m, i, j); c[i*n+sblock.indexes(i)[j]] = val; lnnz += (val!=0) ? 1 : 0; } @@ -324,7 +324,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl for(int i=rl, cix=rl*n; i<ru; i++, cix+=n) for(int j=0; j<n; j++) { double valij = (sblock != null) ? sblock.get(i, j) : 0; - c[cix+j] = genexecDense(valij, b, scalars, n, m, i, j); + c[cix+j] = genexec(valij, b, scalars, n, m, i, j); lnnz += (c[cix+j]!=0) ? 1 : 0; } } @@ -344,7 +344,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl int alen = sblock.size(i); double[] avals = sblock.values(i); for( int j=apos; j<apos+alen; j++ ) { - kplus.execute2(kbuff, genexecDense(avals[j], b, scalars, n, m, i, j)); + kplus.execute2(kbuff, genexec(avals[j], b, scalars, n, m, i, j)); } c[i] = kbuff._sum; lnnz += (c[i]!=0) ? 1 : 0; @@ -357,7 +357,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl kbuff.set(0, 0); for(int j=0; j<n; j++) { double valij = (sblock != null) ? sblock.get(i, j) : 0; - kplus.execute2( kbuff, genexecDense(valij, b, scalars, n, m, i, j)); + kplus.execute2( kbuff, genexec(valij, b, scalars, n, m, i, j)); } c[i] = kbuff._sum; lnnz += (c[i]!=0) ? 1 : 0; @@ -368,7 +368,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl return lnnz; } - protected abstract double genexecDense( double a, double[][] b, double[] scalars, int n, int m, int rowIndex, int colIndex); + protected abstract double genexec( double a, double[][] b, double[] scalars, int n, int m, int rowIndex, int colIndex); private class ParAggTask implements Callable<Double> { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/runtime/codegen/SpoofOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofOperator.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofOperator.java index f4ee915..8cb6692 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofOperator.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofOperator.java @@ -56,17 +56,20 @@ public abstract class SpoofOperator implements Serializable } protected double[][] prepInputMatrices(ArrayList<MatrixBlock> inputs) { - return prepInputMatrices(inputs, 1); + return prepInputMatrices(inputs, 1, inputs.size()-1); } protected double[][] prepInputMatrices(ArrayList<MatrixBlock> inputs, int offset) { - double[][] b = new double[inputs.size()-offset][]; - for(int i=offset; i < inputs.size(); i++) { - //allocate dense block in place for empty blocks - if( inputs.get(i).isEmptyBlock(false) && !inputs.get(i).isAllocated() ) - inputs.get(i).allocateDenseBlock(); - //convert sparse to dense temporary block - if( inputs.get(i).isInSparseFormat() ) { + return prepInputMatrices(inputs, offset, inputs.size()-offset); + } + + protected double[][] prepInputMatrices(ArrayList<MatrixBlock> inputs, int offset, int len) { + double[][] b = new double[len][]; + for(int i=offset; i<offset+len; i++) { + //convert empty or sparse to dense temporary block (note: we don't do + //this in place because this block might be used by multiple threads) + if( (inputs.get(i).isEmptyBlock(false) && !inputs.get(i).isAllocated()) + || inputs.get(i).isInSparseFormat() ) { MatrixBlock tmp = inputs.get(i); b[i-offset] = DataConverter.convertToDoubleVector(tmp); LOG.warn("Converted "+tmp.getNumRows()+"x"+tmp.getNumColumns() + @@ -77,6 +80,7 @@ public abstract class SpoofOperator implements Serializable b[i-offset] = inputs.get(i).getDenseBlock(); } } + return b; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/runtime/codegen/SpoofOuterProduct.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofOuterProduct.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofOuterProduct.java index a23ea5a..9f4a6de 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofOuterProduct.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofOuterProduct.java @@ -69,6 +69,7 @@ public abstract class SpoofOuterProduct extends SpoofOperator throw new RuntimeException("Invalid input arguments."); //input preparation + double[][] ab = prepInputMatrices(inputs, 1, 2); double[][] b = prepInputMatrices(inputs, 3); double[] scalars = prepInputScalars(scalarObjects); @@ -77,18 +78,14 @@ public abstract class SpoofOuterProduct extends SpoofOperator final int n = inputs.get(0).getNumColumns(); final int k = inputs.get(1).getNumColumns(); // rank - //public static void matrixMultWDivMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock mX, MatrixBlock ret, WDivMMType wt, int k) MatrixBlock a = inputs.get(0); - MatrixBlock u = inputs.get(1); - MatrixBlock v = inputs.get(2); - MatrixBlock out = new MatrixBlock(1, 1, false); out.allocateDenseBlock(); if(!a.isInSparseFormat()) - executeCellwiseDense(a.getDenseBlock(), u.getDenseBlock(), v.getDenseBlock(), b, scalars, out.getDenseBlock(), n, m, k, _outerProductType, 0, m, 0, n); + executeCellwiseDense(a.getDenseBlock(), ab[0], ab[1], b, scalars, out.getDenseBlock(), n, m, k, _outerProductType, 0, m, 0, n); else - executeCellwiseSparse(a.getSparseBlock(), u.getDenseBlock(), v.getDenseBlock(), b, scalars, out, n, m, k, (int) a.getNonZeros(), _outerProductType, 0, m, 0, n); + executeCellwiseSparse(a.getSparseBlock(), ab[0], ab[1], b, scalars, out, n, m, k, (int) a.getNonZeros(), _outerProductType, 0, m, 0, n); return new DoubleObject(out.getDenseBlock()[0]); } @@ -101,6 +98,7 @@ public abstract class SpoofOuterProduct extends SpoofOperator throw new RuntimeException("Invalid input arguments."); //input preparation + double[][] ab = prepInputMatrices(inputs, 1, 2); double[][] b = prepInputMatrices(inputs, 3); double[] scalars = prepInputScalars(scalarObjects); @@ -118,7 +116,7 @@ public abstract class SpoofOuterProduct extends SpoofOperator //for wdivmm-right, parallelization over rows; both ensure disjoint results) int blklen = (int)(Math.ceil((double)m/numThreads)); for( int i=0; i<numThreads & i*blklen<m; i++ ) - tasks.add(new ParOuterProdAggTask(inputs.get(0), inputs.get(1).getDenseBlock(), inputs.get(2).getDenseBlock(), b, scalars, n, m, k, _outerProductType, i*blklen, Math.min((i+1)*blklen,m), 0, n)); + tasks.add(new ParOuterProdAggTask(inputs.get(0), ab[0], ab[1], b, scalars, n, m, k, _outerProductType, i*blklen, Math.min((i+1)*blklen,m), 0, n)); //execute tasks List<Future<Double>> taskret = pool.invokeAll(tasks); pool.shutdown(); @@ -164,6 +162,7 @@ public abstract class SpoofOuterProduct extends SpoofOperator } //input preparation + double[][] ab = prepInputMatrices(inputs, 1, 2); double[][] b = prepInputMatrices(inputs, 3); double[] scalars = prepInputScalars(scalarObjects); @@ -173,23 +172,21 @@ public abstract class SpoofOuterProduct extends SpoofOperator final int k = inputs.get(1).getNumColumns(); // rank MatrixBlock a = inputs.get(0); - MatrixBlock u = inputs.get(1); - MatrixBlock v = inputs.get(2); switch(_outerProductType) { case LEFT_OUTER_PRODUCT: case RIGHT_OUTER_PRODUCT: if( !a.isInSparseFormat() ) - executeDense(a.getDenseBlock(), u.getDenseBlock(), v.getDenseBlock(), b, scalars, out.getDenseBlock(), n, m, k, _outerProductType, 0, m, 0, n); + executeDense(a.getDenseBlock(), ab[0], ab[1], b, scalars, out.getDenseBlock(), n, m, k, _outerProductType, 0, m, 0, n); else - executeSparse(a.getSparseBlock(), u.getDenseBlock(), v.getDenseBlock(), b, scalars, out.getDenseBlock(), n, m, k, (int) a.getNonZeros(), _outerProductType, 0, m, 0, n); + executeSparse(a.getSparseBlock(), ab[0], ab[1], b, scalars, out.getDenseBlock(), n, m, k, (int) a.getNonZeros(), _outerProductType, 0, m, 0, n); break; case CELLWISE_OUTER_PRODUCT: if( !a.isInSparseFormat() ) - executeCellwiseDense(a.getDenseBlock(), u.getDenseBlock(), v.getDenseBlock(), b, scalars, out.getDenseBlock(), n, m, k, _outerProductType, 0, m, 0, n); + executeCellwiseDense(a.getDenseBlock(), ab[0], ab[1], b, scalars, out.getDenseBlock(), n, m, k, _outerProductType, 0, m, 0, n); else - executeCellwiseSparse(a.getSparseBlock(), u.getDenseBlock(), v.getDenseBlock(), b, scalars, out, n, m, k, (int) a.getNonZeros(), _outerProductType, 0, m, 0, n); + executeCellwiseSparse(a.getSparseBlock(), ab[0], ab[1], b, scalars, out, n, m, k, (int) a.getNonZeros(), _outerProductType, 0, m, 0, n); break; case AGG_OUTER_PRODUCT: @@ -210,11 +207,11 @@ public abstract class SpoofOuterProduct extends SpoofOperator throw new RuntimeException("Invalid input arguments."); //check empty result - if( (_outerProductType == OutProdType.LEFT_OUTER_PRODUCT && inputs.get(1).isEmptyBlock(false)) //U is empty - || (_outerProductType == OutProdType.RIGHT_OUTER_PRODUCT && inputs.get(2).isEmptyBlock(false)) //V is empty - || (_outerProductType == OutProdType.CELLWISE_OUTER_PRODUCT && inputs.get(0).isEmptyBlock(false))) { //X is empty - out.examSparsity(); //turn empty dense into sparse - return; + if( (_outerProductType == OutProdType.LEFT_OUTER_PRODUCT && inputs.get(1).isEmptyBlock(false)) //U is empty + || (_outerProductType == OutProdType.RIGHT_OUTER_PRODUCT && inputs.get(2).isEmptyBlock(false)) //V is empty + || (_outerProductType == OutProdType.CELLWISE_OUTER_PRODUCT && inputs.get(0).isEmptyBlock(false))) { //X is empty + out.examSparsity(); //turn empty dense into sparse + return; } //input preparation and result allocation (Allocate the output that is set by Sigma2CPInstruction) @@ -236,6 +233,7 @@ public abstract class SpoofOuterProduct extends SpoofOperator } //input preparation + double[][] ab = prepInputMatrices(inputs, 1, 2); double[][] b = prepInputMatrices(inputs, 3); double[] scalars = prepInputScalars(scalarObjects); @@ -254,12 +252,12 @@ public abstract class SpoofOuterProduct extends SpoofOperator if( _outerProductType == OutProdType.LEFT_OUTER_PRODUCT ) { int blklen = (int)(Math.ceil((double)n/numThreads)); for( int j=0; j<numThreads & j*blklen<n; j++ ) - tasks.add(new ParExecTask(inputs.get(0), inputs.get(1).getDenseBlock(), inputs.get(2).getDenseBlock(), b, scalars, out, n, m, k, _outerProductType, 0, m, j*blklen, Math.min((j+1)*blklen, n))); + tasks.add(new ParExecTask(inputs.get(0), ab[0], ab[1], b, scalars, out, n, m, k, _outerProductType, 0, m, j*blklen, Math.min((j+1)*blklen, n))); } else { ///right // cellwise int blklen = (int)(Math.ceil((double)m/numThreads)); for( int i=0; i<numThreads & i*blklen<m; i++ ) - tasks.add(new ParExecTask(inputs.get(0), inputs.get(1).getDenseBlock(), inputs.get(2).getDenseBlock(), b, scalars, out, n, m, k, _outerProductType, i*blklen, Math.min((i+1)*blklen,m), 0, n)); + tasks.add(new ParExecTask(inputs.get(0), ab[0], ab[1], b, scalars, out, n, m, k, _outerProductType, i*blklen, Math.min((i+1)*blklen,m), 0, n)); } List<Future<Long>> taskret = pool.invokeAll(tasks); pool.shutdown(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java index 52bc926..8dda92a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java @@ -78,8 +78,10 @@ public class SpoofCPInstruction extends ComputationCPInstruction for (CPOperand input : _in) { if(input.getDataType()==DataType.MATRIX) inputs.add(ec.getMatrixInput(input.getName())); - else if(input.getDataType()==DataType.SCALAR) + else if(input.getDataType()==DataType.SCALAR) { + //note: even if literal, it might be compiled as scalar placeholder scalars.add(ec.getScalarInput(input.getName(), input.getValueType(), input.isLiteral())); + } } // set the output dimensions to the hop node matrix dimensions http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index 41bcffb..c78613d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -107,6 +107,7 @@ public class SpoofSPInstruction extends SPInstruction bcVars.add(_in[i].getName()); } else if(_in[i].getDataType()==DataType.SCALAR) { + //note: even if literal, it might be compiled as scalar placeholder scalars.add(sec.getScalarInput(_in[i].getName(), _in[i].getValueType(), _in[i].isLiteral())); } } @@ -308,7 +309,7 @@ public class SpoofSPInstruction extends SPInstruction MatrixBlock blkIn = tmp._2(); MatrixIndexes ixOut = ixIn; MatrixBlock blkOut = new MatrixBlock(); - ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, (int)ixIn.getRowIndex()); + ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, ixIn); //execute core operation if(((SpoofCellwise)_op).getCellType()==CellType.FULL_AGG) { @@ -326,13 +327,16 @@ public class SpoofSPInstruction extends SPInstruction return ret.iterator(); } - private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex) + private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, MatrixIndexes ixIn) throws DMLRuntimeException { ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); ret.add(blkIn); - for( PartitionedBroadcast<MatrixBlock> vector : _vectors ) - ret.add(vector.getBlock((vector.getNumRowBlocks()>=rowIndex)?rowIndex:1, 1)); + for( PartitionedBroadcast<MatrixBlock> in : _vectors ) { + int rowIndex = (int)((in.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1); + int colIndex = (int)((in.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1); + ret.add(in.getBlock(rowIndex, colIndex)); + } return ret; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmKMeans.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmKMeans.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmKMeans.java index 83a1c3a..7ed9b2a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmKMeans.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmKMeans.java @@ -159,7 +159,7 @@ public class AlgorithmKMeans extends AutomatedTestBase /* This is for running the junit test the new way, i.e., construct the arguments directly */ String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[]{ "-explain", "hops", "-stats", + programArgs = new String[]{ "-explain", "-stats", "-args", input("X"), String.valueOf(centroids), String.valueOf(runs), String.valueOf(epsilon), String.valueOf(maxiter), output("C")}; @@ -174,13 +174,6 @@ public class AlgorithmKMeans extends AutomatedTestBase runTest(true, false, null, -1); - //no comparison with R due to randomized algorithm - //runRScript(true); - //compare matrices - //HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C"); - //HashMap<CellIndex, Double> rfile = readRMatrixFromFS("C"); - //TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - Assert.assertTrue(heavyHittersContainsSubString("spoof") || heavyHittersContainsSubString("sp_spoof")); } finally { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java index c6f7a85..e5b4a83 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java @@ -137,7 +137,8 @@ public class AlgorithmLinregCG extends AutomatedTestBase HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("w"); HashMap<CellIndex, Double> rfile = readRMatrixFromFS("w"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - Assert.assertTrue(heavyHittersContainsSubString("spoof") || heavyHittersContainsSubString("sp_spoof")); + Assert.assertTrue(heavyHittersContainsSubString("spoofRA") + || heavyHittersContainsSubString("sp_spoofRA")); } finally { rtplatform = platformOld; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index cc17cc6..4a6259a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -157,7 +157,7 @@ public class CellwiseTmplTest extends AutomatedTestBase runTest(true, false, null, -1); runRScript(true); - //System.exit(1); + if(testname.equals(TEST_NAME6)) //tak+ { //compare scalars @@ -172,7 +172,8 @@ public class CellwiseTmplTest extends AutomatedTestBase HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); if( !(rewrites && testname.equals(TEST_NAME2)) ) //sigmoid - Assert.assertTrue(heavyHittersContainsSubString("spoof") || heavyHittersContainsSubString("sp_spoof")); + Assert.assertTrue(heavyHittersContainsSubString("spoofCell") + || heavyHittersContainsSubString("sp_spoofCell")); } } finally { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/test/java/org/apache/sysml/test/integration/functions/codegen/DAGCellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/DAGCellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/DAGCellwiseTmplTest.java index 5eb5f0b..fd5ecca 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/DAGCellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/DAGCellwiseTmplTest.java @@ -152,7 +152,8 @@ public class DAGCellwiseTmplTest extends AutomatedTestBase HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S"); HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - Assert.assertTrue(heavyHittersContainsSubString("spoof") || heavyHittersContainsSubString("sp_spoof")); + Assert.assertTrue(heavyHittersContainsSubString("spoofCell") + || heavyHittersContainsSubString("sp_spoofCell")); } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrites; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c5eddccf/src/test/java/org/apache/sysml/test/integration/functions/codegen/OuterProdTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/OuterProdTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/OuterProdTmplTest.java index d7b423f..d87ffe5 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/OuterProdTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/OuterProdTmplTest.java @@ -132,8 +132,6 @@ public class OuterProdTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME7, false, ExecType.CP ); } - //TODO - @Test public void testCodegenOuterProdRewrite1_sp() { testCodegenIntegrationWithInput( TEST_NAME1, true, ExecType.SPARK ); @@ -191,7 +189,8 @@ public class OuterProdTmplTest extends AutomatedTestBase HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); if( !rewrites ) - Assert.assertTrue(heavyHittersContainsSubString("spoof") || heavyHittersContainsSubString("sp_spoof")); + Assert.assertTrue(heavyHittersContainsSubString("spoofOP") + || heavyHittersContainsSubString("sp_spoofOP")); } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; @@ -233,22 +232,22 @@ public class OuterProdTmplTest extends AutomatedTestBase runTest(true, false, null, -1); runRScript(true); - if(testname.equals(TEST_NAME4)) //wcemm - { + if(testname.equals(TEST_NAME4)) { //wcemm //compare scalars HashMap<CellIndex, Double> dmlfile = readDMLScalarFromHDFS("S"); HashMap<CellIndex, Double> rfile = readRScalarFromFS("S"); TestUtils.compareScalars((Double) dmlfile.values().toArray()[0], (Double) rfile.values().toArray()[0],0.0001); } - else - { + else { //compare matrices HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S"); HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - if( !rewrites ) - Assert.assertTrue(heavyHittersContainsSubString("spoof") || heavyHittersContainsSubString("sp_spoof")); } + + if( !rewrites ) + Assert.assertTrue(heavyHittersContainsSubString("spoofOP") + || heavyHittersContainsSubString("sp_spoofOP")); } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
