// http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java
// ----------------------------------------------------------------------
// diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java b/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java
// new file mode 100644
// index 0000000..a4bcffe
// --- /dev/null
// +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java
// @@ -0,0 +1,168 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map.Entry;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeCell;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeRowAggVector;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.utils.Statistics;

/**
 * Registry of the cplans produced by the individual code generation templates.
 */
public class CplanRegister {

	// Layout of the registry:
	//  * outer HashMap: template type -> list of all patterns (operator chains) fused by that template
	//  * inner LinkedHashMap (one per pattern): hop ID of a fused hop -> pair of
	//    (input hops of the fused operation, cplan template)
	// An inner map also holds the intermediate cplans of its pattern, e.g., for
	// log(exp(round(X))) it has three keys for the three hops (log, exp and round);
	// the key that was inserted last is the key of the hop to be fused.
	private HashMap<TemplateType, ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>>> _cplans;

	public CplanRegister() {
		_cplans = new HashMap<TemplateType, ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>>>();
	}

	/**
	 * Registers the cplans of one pattern under the given template type and,
	 * if statistics are enabled, counts one compiled cplan.
	 *
	 * @param type template type the cplans belong to
	 * @param cplans cplans of a single pattern, keyed by hop ID
	 */
	public void insertCpplans(TemplateType type, LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> cplans) {
		ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>> patterns = _cplans.get(type);
		if( patterns == null ) {
			patterns = new ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>>();
			_cplans.put(type, patterns);
		}
		patterns.add(cplans);

		//count one plan per call; cplans.size() would also contain all subsets of cpplans
		if( DMLScript.STATISTICS )
			Statistics.incrementCodegenCPlanCompile(1);
	}

	/**
	 * Checks whether any pattern registered for the given template type
	 * contains a cplan for the given hop ID.
	 *
	 * @param type template type to probe
	 * @param hopID hop ID to look up
	 * @return true if a registered pattern of that type has the hop ID as key
	 */
	public boolean containsHop(TemplateType type, long hopID) {
		ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>> patterns = _cplans.get(type);
		if( patterns == null )
			return false;

		for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> pattern : patterns ) {
			if( pattern.containsKey(hopID) )
				return true;
		}
		return false;
	}

	/**
	 * Returns the top-level cplan of every registered pattern, after removing
	 * overlap between template types and merging cellwise plans into row
	 * aggregate plans where possible.
	 *
	 * @return top-level cplans keyed by hop ID; empty if nothing is registered
	 */
	public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> getTopLevelCplans()
	{
		LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> topLevel = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>();
		if( _cplans.isEmpty() )
			return topLevel;

		//resolve conflicts, i.e., overlap, between template types
		resolvePlanConflicts();

		//extract top level (subsuming) cplans per type and operator chain
		for( ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>> patterns : _cplans.values() ) {
			for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> intermediateCplans : patterns ) {
				Entry<Long, Pair<Hop[],CNodeTpl>> top = TemplateUtils.getTopLevelCpplan(intermediateCplans);
				if( top == null )
					continue;
				topLevel.put(top.getKey(), top.getValue());
			}
		}

		//merge top level plans if possible //TODO move to rowagg template
		return mergeRowAggregateCellwisePlans(topLevel);
	}

	/**
	 * Resolves conflicts between overlapping cplans of different types:
	 * outer product plans and row aggregate plans are both preferred over
	 * cellwise plans, so hop IDs they cover are removed from the cellwise plans.
	 */
	private void resolvePlanConflicts()
	{
		//get different plan categories
		ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> cellwisePlans = _cplans.get(TemplateType.CellTpl);

		//prefer outer product plans over cellwise plans -> remove overlap
		removeOverlap(_cplans.get(TemplateType.OuterProductTpl), cellwisePlans);

		//prefer row aggregate plans over cellwise plans -> remove overlap
		removeOverlap(_cplans.get(TemplateType.RowAggTpl), cellwisePlans);
	}

	/**
	 * Removes every hop ID that occurs as key in one of the preferred plans
	 * from all of the subordinate plans. Does nothing if either list is null.
	 */
	private static void removeOverlap(ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> preferredPlans,
		ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> subordinatePlans)
	{
		if( preferredPlans == null || subordinatePlans == null )
			return;

		for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> preferred : preferredPlans ) {
			for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> subordinate : subordinatePlans ) {
				for( Long hopID : preferred.keySet() )
					subordinate.remove(hopID);
			}
		}
	}

	/**
	 * Merges cellwise plans into the row aggregate plans that consume them as
	 * secondary inputs, provided the cell template has no multiple consumers.
	 * Merged cell plans are dropped from the result.
	 *
	 * @param plans top-level plans keyed by hop ID (not modified as a map; the
	 *        contained pairs and templates of merged plans are updated in place)
	 * @return copy of the given plans without the merged cell plans
	 */
	private static LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> mergeRowAggregateCellwisePlans(LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> plans)
	{
		LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> ret = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>(plans);

		//extract row aggregate templates
		HashMap<Long, Pair<Hop[],CNodeTpl>> rowaggPlans = new HashMap<Long, Pair<Hop[],CNodeTpl>>();
		for( Entry<Long, Pair<Hop[],CNodeTpl>> plan : plans.entrySet() ) {
			if( plan.getValue().getValue() instanceof CNodeRowAggVector )
				rowaggPlans.put(plan.getKey(), plan.getValue());
		}

		//probe and merge row aggregate secondary inputs (by definition vectors)
		for( Entry<Long, Pair<Hop[],CNodeTpl>> e : rowaggPlans.entrySet() ) {
			//check all inputs (except the first) for existing cell plans
			//NOTE(review): 'inputs' is read once before any merge; after setKey(..) below,
			//later iterations still pass this original array to mergeDistinct - confirm
			//this is intended when more than one input gets merged.
			Hop[] inputs = e.getValue().getKey();
			for( int pos=1; pos<inputs.length; pos++ ) {
				long inhopID = inputs[pos].getHopID();
				if( !ret.containsKey(inhopID) )
					continue;
				if( !(ret.get(inhopID).getValue() instanceof CNodeCell) )
					continue;
				if( ((CNodeCell)ret.get(inhopID).getValue()).hasMultipleConsumers() )
					continue;

				//merge row agg template
				CNodeRowAggVector rowaggtpl = (CNodeRowAggVector) e.getValue().getValue();
				CNodeCell celltpl = (CNodeCell)ret.get(inhopID).getValue();
				CNode cellMainInput = celltpl.getInput().get(0);
				cellMainInput.setDataType(DataType.MATRIX);
				rowaggtpl.rReplaceDataNode(rowaggtpl.getOutput(), inhopID, celltpl.getOutput());
				long cellMainInputID = ((CNodeData)celltpl.getInput().get(0)).getHopID();
				rowaggtpl.rInsertLookupNode(rowaggtpl.getOutput(), cellMainInputID, new HashMap<Long, CNode>());
				for( CNode cellInput : celltpl.getInput() )
					rowaggtpl.addInput(cellInput);
				HashSet<Long> inputIDs = TemplateUtils.rGetInputHopIDs(rowaggtpl.getOutput(), new HashSet<Long>());
				Hop[] mergedHops = TemplateUtils.mergeDistinct(inputIDs, inputs, ret.get(inhopID).getKey());
				e.getValue().setKey(mergedHops);

				//remove cell template
				ret.remove(inhopID);
			}
		}

		return ret;
	}
}
// http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java
// ----------------------------------------------------------------------
// diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java
// new file mode 100644
// index 0000000..c202d3c
// --- /dev/null
// +++ b/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java
// @@ -0,0 +1,489 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;

import org.apache.sysml.api.DMLException;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.ReOrgOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysml.runtime.matrix.data.Pair;

/**
 * Code generation template for operator chains around an outer-product-like
 * matrix multiplication U %*% t(V). The template is opened on the matrix
 * multiply, searches its parents for the hop that closes the pattern
 * (cellwise, left/right multiply, or full sum), and then builds a
 * CNodeOuterProduct cplan over the inputs X, U and V.
 */
public class OuterProductTpl extends BaseTpl {
	
	public OuterProductTpl() {
		super(TemplateType.OuterProductTpl);
	}
	
	//binary operations that make the other matrix operand the driver X (see rfindOuterProduct)
	private List<OpOp2> sparseDrivers = new ArrayList<OpOp2>(Arrays.asList(OpOp2.MULT, OpOp2.DIV));
	//outer product type determined during boundary search, or inferred lazily in getOuterProductType
	private OutProdType _outerProductType = null;
	//set for left outer products whose matrix multiply has no transpose parent
	private boolean _transposeOutput = false;
	//set when the right operand of the outer product is not already given as t(V)
	private boolean _transposeInput = false;
	
	/**
	 * Checks whether the given hop can open this template: a matrix multiply
	 * with known dimensions (output and both inputs) whose output dimensions
	 * are both greater than the common dimension.
	 *
	 * @param hop candidate hop
	 * @return true if the hop is such an outer-product-like matrix multiply
	 */
	@Override
	public boolean openTpl(Hop hop) {
		// outerproduct ( output dimensions is greater than the common dimension) 
		return ( hop instanceof AggBinaryOp && ((AggBinaryOp)hop).isMatrixMultiply() && hop.dimsKnown()
				&& hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
				&& (hop.getDim1() > hop.getInput().get(0).getDim2() && hop.getDim2() > hop.getInput().get(1).getDim1()) );
	}

	/**
	 * Determines the template boundaries starting from the outer product hop h
	 * (which becomes _endHop) by searching its parents for the closing hop
	 * (_initialHop) and collecting the matrix inputs X, U and V.
	 * On success, _matrixInputs holds the major input matrix, U, V (or a newly
	 * created transpose of V), any additional matrices, and the inputs of _endHop.
	 *
	 * @param h outer product hop that opened the template
	 * @param cplanRegister register of existing cplans (not used by this method)
	 * @return true if X, U and V were found and the closing hop differs from h
	 */
	@Override
	public boolean findTplBoundaries(Hop h, CplanRegister cplanRegister) { 
		_endHop = h;//outerProduct tpl starts with endHop
		HashMap<String,Hop> uniqueMatrixInputs = new HashMap<String,Hop>();
		uniqueMatrixInputs.put("U",  h.getInput().get(0));
		if( h.getInput().get(1) instanceof ReorgOp && ((ReorgOp)h.getInput().get(1)).getOp() == ReOrgOp.TRANSPOSE )
			uniqueMatrixInputs.put("V",  h.getInput().get(1).getInput().get(0));
		else
		{
			_transposeInput = true; // we need to transpose V to be tall and skinny
			uniqueMatrixInputs.put("V",  h.getInput().get(1));
		}
		rfindOuterProduct(_endHop, _endHop, uniqueMatrixInputs, h.getDim1(), h.getDim2(), new HashSet<Long>());
		
		//three entries means that, in addition to U and V, the driver X was found
		if(uniqueMatrixInputs.size() == 3 && _initialHop != null && _initialHop != _endHop )	//sanity check
		{
			//check if added matrices can be inferred from input matrices for example (X!=0) or abs(X) are not different from X
			Hop commonChild = null;
			if(! _adddedMatrices.isEmpty() ) {
				//if addedMatrices does not have a common child with input X then do not compile
				commonChild = TemplateUtils.commonChild(_adddedMatrices,uniqueMatrixInputs.get("X"));
				if(commonChild == null ) // there are multiple matrices involved other than X
						return false;
			}
			if(commonChild != null) {
				_matrixInputs.add(commonChild); //add common child as the major input matrix
				_adddedMatrices.add(uniqueMatrixInputs.get("X")); // put unique matrix as one of the additional matrices that is a chain of cell wise operations for the input matrix
			}
			else {
				_matrixInputs.add(uniqueMatrixInputs.get("X")); //major matrix is the sparse driver
			}
			_matrixInputs.add(uniqueMatrixInputs.get("U"));
			
			if(_transposeInput) {
				ReorgOp transposeV = HopRewriteUtils.createTranspose(uniqueMatrixInputs.get("V"));
				//ReorgOp transposeV = new ReorgOp("", uniqueMatrixInputs.get("V").getDataType(), uniqueMatrixInputs.get("V").getValueType(), ReOrgOp.TRANSPOSE, uniqueMatrixInputs.get("V"));
				_matrixInputs.add(transposeV);
			}
			else {
				_matrixInputs.add(uniqueMatrixInputs.get("V"));
			}
			
			
			//add also added matrices so that they can be interpreted as inputs
			for(Hop addedMatrix : _adddedMatrices)
				if(!_matrixInputs.contains(addedMatrix))
					_matrixInputs.add(addedMatrix);
		
			//add the children of _endHop ( this will handle the case for wdivmm right when I add the both t(V) and V as inputs
			for (Hop hop: _endHop.getInput())
				_matrixInputs.add(hop);
			
			return true;
		}
		else
			return false;
		
	}

	/**
	 * Recursively walks from hop h to its parents to find the hop that closes
	 * the outer product pattern, updating _initialHop, _outerProductType,
	 * _transposeInput/_transposeOutput, _adddedMatrices and the "X"/"V" entries
	 * of uniqueMatrixInputs along the way.
	 *
	 * @param child hop from which h was reached (equal to h for the first call)
	 * @param h hop currently inspected
	 * @param uniqueMatrixInputs named matrix inputs ("U", "V", and "X" once found)
	 * @param outerProductDim1 number of rows of the outer product output
	 * @param outerProductDim2 number of columns of the outer product output
	 * @param memo IDs of hops whose parents were already processed
	 */
	private void rfindOuterProduct(Hop child, Hop h, HashMap<String,Hop> uniqueMatrixInputs, long outerProductDim1, long outerProductDim2, HashSet<Long> memo)
	{
		if(memo.contains(h.getHopID()))
			return;
	
		if( ( h instanceof UnaryOp || h instanceof BinaryOp )  //unary operation or binary operation
				&& h.getDataType() == DataType.MATRIX			 // Output is a matrix
				&& h.getDim1() == outerProductDim1 && h.getDim2() == outerProductDim2 // output is the same size as the matrix
				&& TemplateUtils.isOperationSupported(h))  // operation is supported in codegen
		{
			if(h instanceof BinaryOp)
			{
				
				// find the other child rather than the one that called the parent
				Hop otherChild = h.getInput().get(0) !=  child ? h.getInput().get(0) : h.getInput().get(1);
				
				//if scalar or vector then we fuse it similar to the way we fuse celltpl,
				if(TemplateUtils.isVectorOrScalar(otherChild))
				{
					_initialHop = h;
					_outerProductType = OutProdType.CELLWISE_OUTER_PRODUCT;

				}
				// other child is a  matrix
				else
				{
					//if the binary operation is sparse safe (mult, div)
					if(sparseDrivers.contains(((BinaryOp)h).getOp()) ) 
					{
						//only the first matching matrix operand is registered as X
						if(!uniqueMatrixInputs.containsKey("X"))
						{
							//extra sanity check
							if(otherChild.getDim1() == outerProductDim1 && otherChild.getDim2() == outerProductDim2) {
								uniqueMatrixInputs.put("X", otherChild);
								_initialHop = h;
							}
							else { //matrix size does not match what is expected for X
								return; 
							}
						}
					}
					else {
						_adddedMatrices.add(otherChild);
					}
				}
			}
		}
		
		if(  h instanceof AggBinaryOp && ((AggBinaryOp) h).isMatrixMultiply() && h != child) //make sure that the AggBinaryOp is not the same as the outerproduct that triggered this method
		{
			if(memo.contains(h.getInput().get(0).getHopID())) { // if current node is the parent for the left child then it is right matrix multiply
			
				if (h.getInput().get(1) == uniqueMatrixInputs.get("V") )//right operand is V
				{
					_initialHop = h;
					_outerProductType = OutProdType.RIGHT_OUTER_PRODUCT;
					return;
				}
				//right operand is t(V)
				else if(h.getInput().get(1) instanceof ReorgOp && ((ReorgOp)h.getInput().get(1)).getOp() == ReOrgOp.TRANSPOSE && h.getInput().get(1).getInput().get(0) == uniqueMatrixInputs.get("V") )
				{
					//replace V with T(V)
					uniqueMatrixInputs.put("V", h.getInput().get(1));
					_transposeInput = false; //no need to transpose Input
					_initialHop = h;
					_outerProductType = OutProdType.RIGHT_OUTER_PRODUCT;
					return;
				}
				else
				{
					_initialHop = h.getInput().get(0); // set the child that was processed
					return;	
				}
			}
			else {//left matrix multiply
				
				//left is T(U) 
				if (h.getInput().get(0) instanceof ReorgOp && ((ReorgOp)h.getInput().get(0)).getOp() == ReOrgOp.TRANSPOSE && h.getInput().get(0).getInput().get(0) == uniqueMatrixInputs.get("U") ) 
				{
					_initialHop = h;
					_outerProductType = OutProdType.LEFT_OUTER_PRODUCT;
					//T(T(U) %*% ..)
					for(Hop hParent : h.getParent())
						if(hParent instanceof ReorgOp && ((ReorgOp)hParent).getOp() == ReOrgOp.TRANSPOSE) {
							_initialHop = hParent; // set the transpose hop
							return;
						}	
					//no transpose parent found, hence the output is flagged to be transposed
					_transposeOutput = true;
					return;
				}
				else {
					_initialHop = h.getInput().get(1); // set the child that was processed
					return;	
				}
			}
		}
		
		//a full sum (direction RowCol) closes the pattern as aggregating outer product
		if( h instanceof AggUnaryOp && ((AggUnaryOp) h).getOp() == AggOp.SUM 
			&& ((AggUnaryOp) h).getDirection() == Direction.RowCol)
		{
			_initialHop = h;
			_outerProductType = OutProdType.AGG_OUTER_PRODUCT;
			return;
		}
		
		memo.add(h.getHopID());
		//process parents recursively
		for( Hop parent : h.getParent())
			rfindOuterProduct(h, parent,uniqueMatrixInputs, outerProductDim1,outerProductDim2, memo);
	}
	
	////////////////Helper methods for finding boundaries 

	/**
	 * Returns the outer product type found during boundary search or, if none
	 * was set, infers (and caches) it from the output hop: scalar output means
	 * aggregate; otherwise the output dimensions are compared against V, U, X.
	 *
	 * @param X major input matrix
	 * @param U left factor
	 * @param V right factor
	 * @param out output (root) hop of the pattern
	 * @return the outer product type, or null if none of the checks applies
	 */
	private OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out)
	{
		if (_outerProductType != null)
			return _outerProductType;
				
		
		//extra checks to infer type
		if (out.getDataType() == DataType.SCALAR) // sum
		{
			_outerProductType = OutProdType.AGG_OUTER_PRODUCT;
		}
		else if( isDimsEqual(out,V) && out instanceof ReorgOp) // the second condition is added because sometimes V and U might be same dimensions if the dims of X are equal
		{
			_outerProductType = OutProdType.LEFT_OUTER_PRODUCT;
		}
		else if( isDimsEqual(out,U))
		{
			_outerProductType = OutProdType.RIGHT_OUTER_PRODUCT;
		}
		else if ( isDimsEqual(out,X) )
		{
			_outerProductType = OutProdType.CELLWISE_OUTER_PRODUCT;
		}
		
		return _outerProductType;
	}
	
	/**
	 * Checks whether two hops have the same number of rows and columns.
	 */
	private static boolean isDimsEqual(Hop hop1, Hop hop2)
	{
		if(hop1.getDim1() == hop2.getDim1() && hop1.getDim2() == hop2.getDim2())
			return true;
		return false;
	}
	
	/**
	 * Constructs the cplans for the pattern found by findTplBoundaries.
	 * Reads the first three entries of _matrixInputs (X, U, V), so it expects
	 * findTplBoundaries to have returned true before.
	 *
	 * @param compileLiterals passed through to operand fetching and cellwise fusion
	 * @return the constructed cplans keyed by hop ID
	 * @throws DMLException declared by the override; propagated from rConstructOuterProdCplan
	 */
	@Override
	public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals) throws DMLException {

		//re-assign the dimensions of inputs to match the generated code dimensions

		//matrix X is a scalar in the generated code
		_initialCnodes.add(new CNodeData(_matrixInputs.get(0), 1,1,DataType.SCALAR));
		
		//matrix U (one row with as many columns as U)
		_initialCnodes.add(new CNodeData(_matrixInputs.get(1), 1,(int)_matrixInputs.get(1).getDim2(), DataType.MATRIX));
		
		//matrix V (one row with as many columns as V)
		_initialCnodes.add(new CNodeData(_matrixInputs.get(2), 1,(int)_matrixInputs.get(2).getDim2(),DataType.MATRIX));
		
		rConstructOuterProdCplan(_initialHop, _initialHop, new HashSet<Long>(), compileLiterals);
		return _cpplans;		
	}
	
	/**
	 * Recursively constructs the cplan bottom-up: inputs are processed first,
	 * and once _endHop has been reached, each hop whose inputs are generated is
	 * translated into a cnode and wired into a new or existing template.
	 *
	 * @param root root hop of the pattern (the first call passes _initialHop)
	 * @param hop hop currently processed
	 * @param memo IDs of hops for which code was already generated
	 * @param compileLiterals passed through to operand fetching and cellwise fusion
	 * @throws DMLException declared for failures in the called construction helpers
	 */
	private void rConstructOuterProdCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals) throws DMLException
	{
		if( memo.contains(hop.getHopID()) )
			return;
		//process childs recursively
		for( Hop c : hop.getInput() )
			rConstructOuterProdCplan(root, c, memo, compileLiterals);
		
		//organize the main inputs
		Hop X, U, V;
		X = _matrixInputs.get(0);
		U = _matrixInputs.get(1);
		V = _matrixInputs.get(2);
		if(hop==_endHop)
			_endHopReached = true;
		
		// first hop to enter here should be _endHop
		if(TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans) && _endHopReached)  // if direct children are DataGenOps, literals, or already in the cpplans then we are ready to generate code
		{
			CNodeOuterProduct outerProdTmpl = null;
			
			//Fetch operands
			CNode out = null;
			ArrayList<CNode> addedCNodes = new ArrayList<CNode>();
			ArrayList<Hop> addedHops = new ArrayList<Hop>();
			ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals);
			
			//if operands are scalar or independent from X 
			//(never the case for the root hop itself)
			boolean independentOperands = hop != root && (hop.getDataType() == DataType.SCALAR || TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[]{_matrixInputs.get(0).getName(),_matrixInputs.get(1).getName(),_matrixInputs.get(2).getName()}));
			if(!independentOperands)
			{
				if(hop instanceof UnaryOp)
				{
					CNode cdata1 = cnodeData.get(0);
					
					//Primitive Operation has the same name as Hop Type OpOp1
					String primitiveOpName = ((UnaryOp)hop).getOp().toString();
					out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
				}
				else if(hop instanceof BinaryOp)
				{
					CNode cdata1 = cnodeData.get(0);
					CNode cdata2 = cnodeData.get(1);
					
					//Primitive Operation has the same name as Hop Type OpOp2
					String primitiveOpName = ((BinaryOp)hop).getOp().toString();
					
					//cdata1 is a vector
					if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) )
					{
						//second argument is always the vector
						cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP);
						//out = new CNodeBinary(tmp, cdata2, BinType.valueOf(primitiveOpName));
					}
					//cdata1 is a matrix
					else if ( (cdata1.getNumRows() > 1 && cdata1.getNumCols() > 1) )
					{
						CellTpl cellTpl = new CellTpl();
						cdata1 = cellTpl.fuseCellWise(hop.getInput().get(0), _matrixInputs.get(0), compileLiterals); // second argument is always matrix X
						//abort for this hop if the cellwise fusion did not produce a cnode
						if (cdata1 == null)
							return;
					}
					//cdata2 is vector
					//else if( cdata2 instanceof CNodeData && (((CNodeData)cdata2).getNumRows() > 1 && ((CNodeData)cdata2).getNumCols() == 1) || ( ((CNodeData)cdata2).getNumRows() == 1 && ((CNodeData)cdata2).getNumCols() > 1  ))
					if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) )
					{
						cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP);
						//out = new CNodeBinary(cdata1, tmp, BinType.valueOf(primitiveOpName));
					}
					//cdata2 is a matrix
					else if ( (cdata2.getNumRows() > 1 && cdata2.getNumCols() > 1) )
					{
						CellTpl cellTpl = new CellTpl();
						cdata2 = cellTpl.fuseCellWise(hop.getInput().get(1), _matrixInputs.get(0), compileLiterals); // second argument is always matrix X
						//abort for this hop if the cellwise fusion did not produce a cnode
						if (cdata2 == null)
							return;
					}
					out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
				}
				else if(hop instanceof AggBinaryOp)
				{
					CNode cdata1 = cnodeData.get(0);
					CNode cdata2 = cnodeData.get(1); // remember that we already fetched what is under transpose
					
					//outerproduct U%*%t(V) then we should have passsed in V as the input
					if(hop.getInput().get(0) == U && hop.getInput().get(1) instanceof ReorgOp && hop.getInput().get(1).getInput().get(0)  == V)
					{
						//re-assign cdata2 to read V instead of t(V)
						cdata2 = _initialCnodes.get(2); // the initialCNodes holds V
						out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
					}
					
					//outerproduct U%*%V then we should have passsed in trnasposeV as the input
					else if(hop.getInput().get(0) == U &&  V instanceof ReorgOp && V.getInput().get(0)== hop.getInput().get(1))
					{
						//re-assign cdata2 to read t(V) instead of V
						cdata2 = _initialCnodes.get(2); // the initialCNodes holds transpose of V
						out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
					}
					//outerproduct U%*%V but not right wdivmm so we did not pass T(V)
					else if(hop.getInput().get(0) == U &&  hop.getInput().get(1) == V )
					{
						//re-assign cdata2 to read t(V) instead of V
						cdata2 = _initialCnodes.get(2); // the initialCNodes holds transpose of V
						out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
					}
					
					//left outerproduct (i.e., left operand is T(U) )
					else if(hop.getInput().get(0) instanceof ReorgOp && hop.getInput().get(0).getInput().get(0)  == U)
					{
						//scalar is cdata2
						out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
					}
					
					//right outerproduct (i.e., right operand is V )
					else if(hop.getInput().get(1) != U && hop.getInput().get(1) == V)
					{
						cdata2 = _initialCnodes.get(2);
						out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
					}
					
					//right outerproduct (i.e., right operand is t(V) )
					else if(hop.getInput().get(1) instanceof ReorgOp && hop.getInput().get(1).getInput().get(0)  == V)
					{
						cdata2 = _initialCnodes.get(2);
						out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
					}
				}
				else if ( hop instanceof ReorgOp && ((ReorgOp)hop).getOp() == ReOrgOp.TRANSPOSE && root == hop) // if transpose wire the oinput in T( T(U ...)
				{
					out = cnodeData.get(0);
				}
				//full sum at the root: wire the input directly as output
				else if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM && root == hop
					&& ((AggUnaryOp)hop).getDirection() == Direction.RowCol )
				{
					out = cnodeData.get(0);
				}
			}
			// wire output to the template
			if(out != null || independentOperands)
			{
				if(_cpplans.isEmpty())
				{
					//first initialization has to have the first variable as input
					ArrayList<CNode> initialInputs = new ArrayList<CNode>();
					
					if(independentOperands) // pass the hop itself as an input instead of its children
					{
						CNode c =  new CNodeData(hop);
						initialInputs.addAll(_initialCnodes);
						initialInputs.add(c);
						outerProdTmpl = new CNodeOuterProduct(initialInputs, c); 
						outerProdTmpl.setOutProdType(getOuterProductType(X, U, V, root));
						outerProdTmpl.setTransposeOutput(_transposeOutput);
						_cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {X,U,V,hop} ,outerProdTmpl));
					}
					else
					{
						initialInputs.addAll(_initialCnodes);
						initialInputs.addAll(cnodeData);
						outerProdTmpl =  new CNodeOuterProduct(initialInputs, out); 
						outerProdTmpl.setOutProdType(getOuterProductType(X, U, V, root));
						outerProdTmpl.setTransposeOutput(_transposeOutput);
						
						//input hops of the plan: X, U, V followed by all added hops
						Hop[] hopArray = new Hop[addedHops.size()+3];
						hopArray[0] = X;
						hopArray[1] = U;
						hopArray[2] = V;
						
						System.arraycopy( addedHops.toArray(), 0, hopArray, 3, addedHops.size());
						
						_cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,outerProdTmpl));
					}
				}
				else
				{
					if(independentOperands)
					{
						CNode c =  new CNodeData(hop);
						//clear Operands
						addedCNodes.clear();
						addedHops.clear();
						
						//added the current hop as the input
						addedCNodes.add(c);
						addedHops.add(hop);
						out = c;
					}
					
					//wire the output to existing or new template	
					TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops);
				}
			}
			//the hop is memoized only inside this branch, i.e., once its inputs
			//were generated and the end hop has been reached
			memo.add(hop.getHopID());
		}
	}
}

// http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
// ----------------------------------------------------------------------
// diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
// new file mode 100644
// index 0000000..0aff9ae
// --- /dev/null
// +++ b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
// @@ -0,0 +1,321 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
+ */ + +package org.apache.sysml.hops.codegen.template; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.LinkedHashMap; + +import org.apache.sysml.api.DMLException; +import org.apache.sysml.hops.AggBinaryOp; +import org.apache.sysml.hops.AggUnaryOp; +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.DataGenOp; +import org.apache.sysml.hops.DataOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.UnaryOp; +import org.apache.sysml.hops.codegen.cplan.CNode; +import org.apache.sysml.hops.codegen.cplan.CNodeBinary; +import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; +import org.apache.sysml.hops.codegen.cplan.CNodeData; +import org.apache.sysml.hops.codegen.cplan.CNodeRowAggVector; +import org.apache.sysml.hops.codegen.cplan.CNodeTpl; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.Hop.Direction; +import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.ReOrgOp; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.matrix.data.Pair; + +public class RowAggTpl extends BaseTpl { + + public RowAggTpl() { + super(TemplateType.RowAggTpl); + } + + @Override + public boolean openTpl(Hop hop) { + if ( (hop instanceof AggBinaryOp || hop instanceof AggUnaryOp) // An aggregate operation + && ( (hop.getDim1()==1 && hop.getDim2()!=1) || (hop.getDim1()!=1 && hop.getDim2()==1) ) )// the output is a vector + return true; + return false; + } + + @Override + public boolean findTplBoundaries(Hop initialHop, CplanRegister cplanRegister) { + _initialHop = initialHop; + if(initialHop instanceof AggBinaryOp) { + // for simplicity we assume that the first operand should be t(X) however, it could be later on W.T(X) + if(initialHop.getInput().get(0) 
instanceof ReorgOp && ((ReorgOp)initialHop.getInput().get(0)).getOp()== ReOrgOp.TRANSPOSE ) + _matrixInputs.add(initialHop.getInput().get(0).getInput().get(0)); //add what is under the transpose + else + return false; + } + rFindRowAggPattern(initialHop, new HashSet<Long>()); + + if(cplanRegister.containsHop(TemplateType.RowAggTpl, initialHop.getHopID())) + return false; + + return (_endHop != null); + } + + + private void rFindRowAggPattern(Hop h, HashSet<Long> memo) + { + if(memo.contains(h.getHopID()) || h.getDataType() == DataType.SCALAR + || h instanceof DataOp || h instanceof DataGenOp || h instanceof LiteralOp) { + return; + } + + boolean continueTraversing = false; + if (h instanceof AggBinaryOp) + { + if(h != _initialHop) { + //T(X) % ..... X %*% v ,check that X is the same as what we saw previously under transpose + if( h.getInput().get(0).equals(_matrixInputs.get(0)) && TemplateUtils.isVector(h.getInput().get(1)) ) { + _endHop = h; + } + } + else { + continueTraversing = true; + } + } + // if initial hop is colSums continue + else if(h instanceof AggUnaryOp && (((AggUnaryOp)_initialHop).getDirection() == Direction.Col && ((AggUnaryOp)_initialHop).getOp() == AggOp.SUM ) && h == _initialHop) + { + continueTraversing=true; + } + //rowSums(X) + else if(h instanceof AggUnaryOp && ((AggUnaryOp)h).getDirection() == Direction.Row && ((AggUnaryOp)h).getOp() == AggOp.SUM ) + { + // check if root pattern is colsums + if((((AggUnaryOp)_initialHop).getDirection() == Direction.Col && ((AggUnaryOp)_initialHop).getOp() == AggOp.SUM )) + { + + //TODO Now the pattern is limited to finding rowSums + _matrixInputs.add(h.getInput().get(0)); + _endHop = h; + } + } + // unary operation || binary operation with first input as a matrix || binary operation with second input as a matrix + else if( ( h instanceof UnaryOp || (h instanceof BinaryOp && h.getInput().get(0).getDataType() == DataType.MATRIX && TemplateUtils.isVectorOrScalar(h.getInput().get(1))) || (h instanceof BinaryOp 
&& TemplateUtils.isVectorOrScalar(h.getInput().get(0)) && h.getInput().get(1).getDataType() == DataType.MATRIX) ) //unary operation or binary operaiton with one matrix and a scalar + && h.getDataType() == DataType.MATRIX // Output is a matrix + && TemplateUtils.isOperationSupported(h) ) //Operation is supported in codegen + { + continueTraversing = true; + } + + //check if we should continue traversing + if(!continueTraversing) + { + return; // stop traversing if conditions does not apply + } + else + { + //process childs recursively + for( Hop in : h.getInput() ) + rFindRowAggPattern(in,memo); + } + memo.add(h.getHopID()); + } + + @Override + public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals) + throws DMLException { + + //re-assign the dimensions of inputs to match the generated code dimensions + _initialCnodes.add(new CNodeData(_matrixInputs.get(0))); + + rConstructRowAggCplan(_initialHop,_initialHop,new HashSet<Long>(), compileLiterals); + return _cpplans; + } + + private void rConstructRowAggCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals) throws DMLException + { + if( memo.contains(hop.getHopID()) ) + return; + //process childs recursively + for( Hop c : hop.getInput() ) + rConstructRowAggCplan(root, c, memo, compileLiterals); + if(hop == _endHop) + _endHopReached = true; + + // first hop to enter here should be _endHop + if(TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans) && _endHopReached) // if direct children are DataGenOps, literals, or already in the cpplans then we are ready to generate code + { + CNodeRowAggVector rowTmpl = null; + + //Fetch operands + CNode out = null; + ArrayList<CNode> addedCNodes = new ArrayList<CNode>(); + ArrayList<Hop> addedHops = new ArrayList<Hop>(); + ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals); + + //if operands are scalar or independent from X + boolean 
independentOperands = hop.getDataType() == DataType.SCALAR + || TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[] {_matrixInputs.get(0).getName()}); + + if(!independentOperands) + { + + if(hop instanceof AggUnaryOp) + { + CNode cdata1 = cnodeData.get(0); + //set the out cnode based on the operation + if( ((AggUnaryOp)hop).getDirection() == Direction.Row && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) //RowSums + { + if(hop.getInput().get(0).getDim2()==1) + out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP); + else + out = new CNodeUnary(cdata1, UnaryType.ROW_SUMS); + } + // if colsums is the root hop, wire the input to the out because colsums it is done automatically by the template + else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM && root == hop) + { + //vector div add without temporary copy + if(cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType()==BinType.VECT_DIV_SCALAR) + out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), BinType.VECT_DIV_ADD); + else + out = cdata1; + } + } + else if(hop instanceof AggBinaryOp) + { + //Fetch operands specific to the operation + CNode cdata1 = cnodeData.get(0); + CNode cdata2 = cnodeData.get(1); + + //choose the operation based on the transpose + if( hop.getInput().get(0) instanceof ReorgOp && ((ReorgOp)hop.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE ) + { + //fetch the data inside the transpose + //cdata1 = new CNodeData(hop.getInput().get(0).getInput().get(0).getName(), (int)hop.getInput().get(0).getInput().get(0).getDim1(), (int)hop.getInput().get(0).getInput().get(0).getDim2()); + out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD); + } + else + { + if(hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1) + out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? 
cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), + (cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT); + else + out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); + } + } + else if(hop instanceof BinaryOp) + { + CNode cdata1 = cnodeData.get(0); + CNode cdata2 = cnodeData.get(1); + + // if one input is a matrix then we need to do vector by scalar operations + if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 ) + { + if (((BinaryOp)hop).getOp()== OpOp2.DIV) + //CNode generatedScalar = new CNodeData("1", 0, 0); // generate literal in order to rewrite the div to x * 1/y + //CNode outScalar = new CNodeBinary(generatedScalar, cdata2, BinType.SCALAR_DIVIDE); + //out = new CNodeBinary(outScalar, cdata1, BinType.VECT_MULT_ADD); + out = new CNodeBinary(cdata1, cdata2, BinType.VECT_DIV_SCALAR); + + } + else //one input is a vector/scalar other is a scalar + { + //Primitive Operation has the same name as Hop Type OpOp2 + String primitiveOpName = ((BinaryOp)hop).getOp().toString(); + + if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) ) + { + //second argument is always the vector + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP); + //out = new CNodeBinary(tmp, cdata2, BinType.valueOf(primitiveOpName)); + } + //cdata2 is vector + //else if( cdata2 instanceof CNodeData && (((CNodeData)cdata2).getNumRows() > 1 && ((CNodeData)cdata2).getNumCols() == 1) || ( ((CNodeData)cdata2).getNumRows() == 1 && ((CNodeData)cdata2).getNumCols() > 1 )) + if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) ) + { + cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP); + //out = new CNodeBinary(cdata1, tmp, BinType.valueOf(primitiveOpName)); + } + out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); + } + + } + + if( out.getDataType().isMatrix() ) { + 
out.setNumRows(hop.getDim1()); + out.setNumCols(hop.getDim2()); + } + } + // wire output to the template + if(out != null || independentOperands) + { + if(_cpplans.isEmpty()) + { + //first initialization has to have the first variable as input + ArrayList<CNode> initialInputs = new ArrayList<CNode>(); + + if(independentOperands) // pass the hop itself as an input instead of its children + { + CNode c = new CNodeData(hop); + initialInputs.addAll(_initialCnodes); + initialInputs.add(c); + rowTmpl = new CNodeRowAggVector(initialInputs, c); + _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {_matrixInputs.get(0),hop} ,rowTmpl)); + } + else + { + initialInputs.addAll(_initialCnodes); + initialInputs.addAll(cnodeData); + rowTmpl = new CNodeRowAggVector(initialInputs, out); + + //Hop[] hopArray = new Hop[hop.getInput().size()+1]; + Hop[] hopArray = new Hop[addedHops.size()+1]; + hopArray[0] = _matrixInputs.get(0); + + //System.arraycopy( hop.getInput().toArray(), 0, hopArray, 1, hop.getInput().size()); + System.arraycopy( addedHops.toArray(), 0, hopArray, 1, addedHops.size()); + + _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,rowTmpl)); + } + } + else + { + if(independentOperands) + { + CNode c = new CNodeData(hop); + //clear Operands + addedCNodes.clear(); + addedHops.clear(); + + //added the current hop as the input + addedCNodes.add(c); + addedHops.add(hop); + out = c; + } + //wire the output to existing or new template + TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops); + } + } + memo.add(hop.getHopID()); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java new file mode 
100644 index 0000000..fd8a960 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.template; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.Map.Entry; + +import org.apache.sysml.hops.AggUnaryOp; +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.DataGenOp; +import org.apache.sysml.hops.DataOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.Hop.Direction; +import org.apache.sysml.hops.Hop.ReOrgOp; +import org.apache.sysml.hops.UnaryOp; +import org.apache.sysml.hops.codegen.cplan.CNode; +import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; +import org.apache.sysml.hops.codegen.cplan.CNodeCell; +import org.apache.sysml.hops.codegen.cplan.CNodeData; +import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; +import org.apache.sysml.hops.codegen.cplan.CNodeTpl; +import 
org.apache.sysml.hops.codegen.cplan.CNodeUnary; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; +import org.apache.sysml.runtime.matrix.data.Pair; +import org.apache.sysml.runtime.util.UtilFunctions; + +public class TemplateUtils +{ + public static boolean inputsAreGenerated(Hop parent, ArrayList<Hop> inputs, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans) + { + if( parent instanceof DataOp || parent instanceof DataGenOp || parent instanceof LiteralOp || inputs.contains(parent) ) + return false; + + for(Hop hop : parent.getInput() ) + if(!inputs.contains(hop) && !(hop instanceof DataOp) && !(hop instanceof DataGenOp) && !(hop.getDataType()==DataType.SCALAR) && !isVector(hop) && !(cpplans.containsKey(hop.getHopID())) && !( hop instanceof ReorgOp && ((ReorgOp)hop).getOp() == ReOrgOp.TRANSPOSE && inputsAreGenerated(hop,inputs, cpplans) )) + return false; + return true; + } + + public static ArrayList<CNode> fetchOperands(Hop hop, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans, ArrayList<CNode> addedCNodes, ArrayList<Hop> addedHops, ArrayList<CNodeData> initialCNodes, boolean compileLiterals) + { + ArrayList<CNode> cnodeData = new ArrayList<CNode>(); + for (Hop h: hop.getInput()) + { + CNode cdata = null; + + //CNodeData already in template inputs + for(CNodeData c : initialCNodes) { + if( c.getHopID() == h.getHopID() ) { + cdata = c; + break; + } + } + + if(cdata != null) + { + cnodeData.add(cdata); + continue; + } + //hop already in the cplan + else if(cpplans.containsKey(h.getHopID())) + { + cdata = cpplans.get(h.getHopID()).getValue().getOutput(); + } + else if(h instanceof ReorgOp && ((ReorgOp)h).getOp()==ReOrgOp.TRANSPOSE ) + { + //fetch what is under the transpose + Hop in = h.getInput().get(0); + cdata = new CNodeData(in); + if(in instanceof DataOp || in instanceof DataGenOp ) { + addedCNodes.add(cdata); + addedHops.add(in); + } + } 
+ else + { + //note: only compile literals if forced or integer literals (likely constants) + //to increase reuse potential on literal replacement during recompilation + cdata = new CNodeData(h); + cdata.setLiteral(h instanceof LiteralOp && (compileLiterals + || UtilFunctions.isIntegerNumber(((LiteralOp)h).getStringValue()))); + if( !cdata.isLiteral() ) { + addedCNodes.add(cdata); + addedHops.add(h); + } + } + + cnodeData.add(cdata); + } + return cnodeData; + } + + public static void setOutputToExistingTemplate(Hop hop, CNode out, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans, ArrayList<CNode> addedCNodes, ArrayList<Hop> addedHops) + { + //get the toplevel rowTemp + Entry<Long, Pair<Hop[],CNodeTpl>> cplan = null; + Iterator<Entry<Long, Pair<Hop[],CNodeTpl>>> iterator = cpplans.entrySet().iterator(); + while (iterator.hasNext()) + cplan = iterator.next(); + + CNodeTpl tmpl = cplan.getValue().getValue().clone(); + tmpl.setDataType(hop.getDataType()); + + if(tmpl instanceof CNodeOuterProduct) { + ((CNodeOuterProduct) tmpl).setOutProdType( ((CNodeOuterProduct)cplan.getValue().getValue()).getOutProdType()); + ((CNodeOuterProduct) tmpl).setTransposeOutput(((CNodeOuterProduct)cplan.getValue().getValue()).isTransposeOutput() ); + } + else if( tmpl instanceof CNodeCell ) { + ((CNodeCell)tmpl).setCellType(getCellType(hop)); + ((CNodeCell)tmpl).setMultipleConsumers(hop.getParent().size()>1); + } + + //add extra inputs + for(CNode c : addedCNodes) + tmpl.addInput(c); + + //modify addedHops if they exist + + Hop[] currentInputHops = cplan.getValue().getKey(); + for (Hop h : currentInputHops) + if (addedHops.contains(h)) + addedHops.remove(h); + + Hop[] extendedHopInputs = new Hop[cplan.getValue().getKey().length + addedHops.size()]; + System.arraycopy(cplan.getValue().getKey(), 0, extendedHopInputs, 0, cplan.getValue().getKey().length); + for(int j=addedHops.size(); j > 0; j--) + extendedHopInputs[extendedHopInputs.length-j] = addedHops.get(addedHops.size() - j); //append the 
added hops to the end of the array + + //set the template output and add it to the cpplans + Pair<Hop[],CNodeTpl> pair = new Pair<Hop[],CNodeTpl>(extendedHopInputs,tmpl); + pair.getValue().setOutput(out); + cpplans.put(hop.getHopID(), pair); + + } + + public static boolean isOperandsIndependent(ArrayList<CNode> cnodeData, ArrayList<Hop> addedHops, String[] varNames) + { + for(CNode c : cnodeData) { + // it is some variable inside the cplan // TODO needs to be modified because sometimes the varname is not null but the variable is in the cplan + if(c.getVarname() == null) + return false; + //if one of the operands is is any of the varnames // if one of the operands is T(X) this condition will apply as well because during fetch operands we fetch what is inside transpose + for(String varName : varNames) + if(c.getVarname().equals(varName)) + return false; + } + return true; + } + + public static Entry<Long, Pair<Hop[],CNodeTpl>> getTopLevelCpplan(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) + { + Entry<Long, Pair<Hop[],CNodeTpl>> ret = null; + + //get last entry (most fused operators) or special handling + boolean hasExp = false; + for( Entry<Long, Pair<Hop[],CNodeTpl>> e : cplans.entrySet() ) + { + ret = e; //keep last seen entry + + //special handling overlapping fused operators with exp + hasExp |= (ret.getValue().getValue().getOutput() instanceof CNodeUnary + && ((CNodeUnary)ret.getValue().getValue().getOutput()).getType()==UnaryType.EXP); + + if( hasExp && ret.getValue().getValue() instanceof CNodeCell + && ((CNodeCell)ret.getValue().getValue()).hasMultipleConsumers() ) + break; + } + + return ret; + } + + public static boolean isVector(Hop hop) { + return (hop.getDataType() == DataType.MATRIX + && (hop.getDim1() != 1 && hop.getDim2() == 1 + || hop.getDim1() == 1 && hop.getDim2() != 1 ) ); + } + + public static boolean isColVector(CNode hop) { + return (hop.getDataType() == DataType.MATRIX + && hop.getNumRows() != 1 && hop.getNumCols() == 1); + } + + public static 
boolean isRowVector(CNode hop) { + return (hop.getDataType() == DataType.MATRIX + && hop.getNumRows() == 1 && hop.getNumCols() != 1); + } + + public static boolean isMatrix(Hop hop) { + return (hop.getDataType() == DataType.MATRIX && hop.getDim1() != 1 && hop.getDim2()!=1); + } + + public static boolean isVectorOrScalar(Hop hop) { + return hop.dimsKnown() && (hop.getDataType() == DataType.SCALAR || isVector(hop) ); + } + + public static boolean isBinaryMatrixRowVector(Hop hop) { + if( !(hop instanceof BinaryOp) ) + return false; + Hop left = hop.getInput().get(0); + Hop right = hop.getInput().get(1); + return left.dimsKnown() && right.dimsKnown() + && left.getDataType().isMatrix() && right.getDataType().isMatrix() + && left.getDim1() > right.getDim1(); + } + + public static boolean isOperationSupported(Hop h) { + if(h instanceof UnaryOp) + return UnaryType.contains(((UnaryOp)h).getOp().toString()); + else if(h instanceof BinaryOp) + return BinType.contains(((BinaryOp)h).getOp().toString()); + else + return false; + } + + private static void rfindChildren(Hop hop, HashSet<Hop> children ) { + if( hop instanceof UnaryOp || (hop instanceof BinaryOp && hop.getInput().get(0).getDataType() == DataType.MATRIX && TemplateUtils.isVectorOrScalar( hop.getInput().get(1))) || (hop instanceof BinaryOp && TemplateUtils.isVectorOrScalar( hop.getInput().get(0)) && hop.getInput().get(1).getDataType() == DataType.MATRIX) //unary operation or binary operaiton with one matrix and a scalar + && hop.getDataType() == DataType.MATRIX ) + { + if(!children.contains(hop)) + children.add(hop); + Hop matrix = TemplateUtils.isMatrix(hop.getInput().get(0)) ? 
hop.getInput().get(0) : hop.getInput().get(1); + rfindChildren(matrix,children); + } + else + children.add(hop); + } + + private static Hop findCommonChild(Hop hop1, Hop hop2) { + //this method assumes that each two nodes have at most one common child + LinkedHashSet<Hop> children1 = new LinkedHashSet<Hop>(); + LinkedHashSet<Hop> children2 = new LinkedHashSet<Hop>(); + + rfindChildren(hop1, children1 ); + rfindChildren(hop2, children2 ); + + //iterate on one set and find the first common child in the other set + Iterator<Hop> iter = children1.iterator(); + while (iter.hasNext()) { + Hop candidate = iter.next(); + if(children2.contains(candidate)) + return candidate; + } + return null; + } + + public static Hop commonChild(ArrayList<Hop> _adddedMatrices, Hop input) { + Hop currentChild = null; + //loop on every added matrix and find its common child with the input, if all of them have the same common child then return it, otherwise null + for(Hop addedMatrix : _adddedMatrices) + { + Hop child = findCommonChild(addedMatrix,input); + if(child == null) // did not find a common child + return null; + if(currentChild == null) // first common child to be seen + currentChild = child; + else if(child.getHopID() != currentChild.getHopID()) + return null; + } + return currentChild; + } + + public static HashSet<Long> rGetInputHopIDs( CNode node, HashSet<Long> ids ) { + if( node instanceof CNodeData && !node.isLiteral() ) + ids.add(((CNodeData)node).getHopID()); + + for( CNode c : node.getInput() ) + rGetInputHopIDs(c, ids); + + return ids; + } + + public static Hop[] mergeDistinct(HashSet<Long> ids, Hop[] input1, Hop[] input2) { + Hop[] ret = new Hop[ids.size()]; + int pos = 0; + for( Hop[] input : new Hop[][]{input1, input2} ) + for( Hop c : input ) + if( ids.contains(c.getHopID()) ) + ret[pos++] = c; + return ret; + } + + private static CellType getCellType(Hop hop) { + return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM) ? 
+ ((((AggUnaryOp) hop).getDirection() == Direction.RowCol) ? + CellType.FULL_AGG : CellType.ROW_AGG) : CellType.NO_AGG; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 7f65ddd..802a382 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -253,6 +253,12 @@ public class HopRewriteUtils child.getParent().add( parent ); } + public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) { + ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent()); + for( Hop lparent : parents ) + HopRewriteUtils.replaceChildReference(lparent, hold, hnew); + } + public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew ) { int pos = getChildReferencePos(parent, inOld); removeChildReferenceByPos(parent, inOld, pos); @@ -491,10 +497,12 @@ public class HopRewriteUtils input2.getDataType().isMatrix() ? 
input2 : input1; BinaryOp bop = new BinaryOp(mainInput.getName(), mainInput.getDataType(), mainInput.getValueType(), op, input1, input2); + //cleanup value type for relational operations + if( bop.isPPredOperation() && bop.getDataType().isScalar() ) + bop.setValueType(ValueType.BOOLEAN); bop.setOutputBlocksizes(mainInput.getRowsInBlock(), mainInput.getColsInBlock()); copyLineNumbers(mainInput, bop); bop.refreshSizeInformation(); - return bop; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java index 558deb3..cea2c93 100644 --- a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java +++ b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java @@ -166,48 +166,21 @@ public class ConvolutionTransform extends Lop } } - // Used by maxpool - public String getInstructions(String input, String stride1, String stride2, String padding1, String padding2, - String input_shape1, String input_shape2, String input_shape3, String input_shape4, - String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4, - String output) throws LopsException { - StringBuilder sb = new StringBuilder(); - appendOpcode(sb); - sb.append( getInputs().get(0).prepInputOperand(input)); - appendOperands(1, 13, output, sb); - return sb.toString(); - } - - // Used by conv2d*, maxpool_bwd - public String getInstructions(String input, String dout, String stride1, String stride2, String padding1, String padding2, - String input_shape1, String input_shape2, String input_shape3, String input_shape4, - String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4, - String output) throws LopsException { - StringBuilder sb = new StringBuilder(); - 
appendOpcode(sb); - sb.append( getInputs().get(0).prepInputOperand(input)); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(1).prepInputOperand(dout)); - appendOperands(2, 14, output, sb); - return sb.toString(); - } - - // Used by fused conv2d+bias_add - public String getInstructions(String input, String bias, String filter, String stride1, String stride2, String padding1, String padding2, - String input_shape1, String input_shape2, String input_shape3, String input_shape4, - String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4, - String output) throws LopsException { + @Override + public String getInstructions(String[] inputs, String output) throws LopsException { StringBuilder sb = new StringBuilder(); appendOpcode(sb); - sb.append( getInputs().get(0).prepInputOperand(input)); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(1).prepInputOperand(bias)); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(2).prepInputOperand(filter)); - appendOperands(3, 15, output, sb); + + for( int i=0; i<inputs.length-12; i++ ) { + if( i > 0 ) + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(i).prepInputOperand(inputs[i])); + } + appendOperands(inputs.length-12, inputs.length, output, sb); + return sb.toString(); } - + public void appendOpcode(StringBuilder sb) { sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/Lop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java index 567b0be..24f7ba3 100644 --- a/src/main/java/org/apache/sysml/lops/Lop.java +++ b/src/main/java/org/apache/sysml/lops/Lop.java @@ -59,6 +59,7 @@ public abstract class Lop SortKeys, PickValues, Checkpoint, //Spark persist into storage level PlusMult, MinusMult, //CP + SpoofFused, 
//CP/SP generated fused operator /** CP operation on a variable number of operands */ MULTIPLE_CP }; @@ -418,6 +419,40 @@ public abstract class Lop return outParams; } + + /** Method should be overridden if needed + * + * @param output output + * @return instructions as string + * @throws LopsException if LopsException occurs + */ + public String getInstructions(String output) throws LopsException { + throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); + } + + /** Method should be overridden if needed + * + * @param input1 input 1 + * @param output output + * @return instructions as string + * @throws LopsException if LopsException occurs + */ + public String getInstructions(String input1, String output) throws LopsException { + throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); + } + + /** Method should be overridden if needed + * + * @param input1 input 1 + * @param input2 input 2 + * @param output output + * @return instructions as string + * @throws LopsException if LopsException occurs + */ + public String getInstructions(String input1, String input2, String output) throws LopsException { + throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); + } + /** * Method should be overridden if needed * @@ -478,6 +513,15 @@ public abstract class Lop public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String output) throws LopsException { throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); } + + public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String input7, String output) throws LopsException { + throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); + } + + public String getInstructions(String[] inputs, String outputs) throws 
LopsException { + throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); + } + public String getInstructions(int output_index) throws LopsException { throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass. Lop Type: " + this.getType()); @@ -541,38 +585,6 @@ public abstract class Lop throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); } - /** Method should be overridden if needed - * - * @param input1 input 1 - * @param input2 input 2 - * @param output output - * @return instructions as string - * @throws LopsException if LopsException occurs - */ - public String getInstructions(String input1, String input2, String output) throws LopsException { - throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); - } - - /** Method should be overridden if needed - * - * @param input1 input 1 - * @param output output - * @return instructions as string - * @throws LopsException if LopsException occurs - */ - public String getInstructions(String input1, String output) throws LopsException { - throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); - } - - /** Method should be overridden if needed - * - * @param output output - * @return instructions as string - * @throws LopsException if LopsException occurs - */ - public String getInstructions(String output) throws LopsException { - throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); - } /** Method should be overridden if needed * @@ -630,37 +642,6 @@ public abstract class Lop return "ERROR: line " + _beginLine + ", column " + _beginColumn + " -- "; } - //TODO: Leo This might get confused with Rand.getInstructions - public String getInstructions(String input, String rowl, String rowu, - String coll, String colu, String leftRowDim, - String leftColDim, String output) throws LopsException { - throw new 
LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); - } - - // stride1, stride2, padding1, padding2 - // input_shape1, input_shape2, input_shape3, input_shape4, - // filter_shape1, filter_shape2, filter_shape3, filter_shape4, - public String getInstructions(String input, String stride1, String stride2, String padding1, String padding2, - String input_shape1, String input_shape2, String input_shape3, String input_shape4, - String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4, - String output) throws LopsException { - throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); - } - - public String getInstructions(String input, String dout, String stride1, String stride2, String padding1, String padding2, - String input_shape1, String input_shape2, String input_shape3, String input_shape4, - String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4, - String output) throws LopsException { - throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); - } - - public String getInstructions(String input, String bias, String dout, String stride1, String stride2, String padding1, String padding2, - String input_shape1, String input_shape2, String input_shape3, String input_shape4, - String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4, - String output) throws LopsException { - throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass"); - } - public String getInstructions(int input, int rowl, int rowu, int coll, int colu, int leftRowDim, int leftColDim, int output) throws LopsException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/SpoofFused.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/SpoofFused.java 
b/src/main/java/org/apache/sysml/lops/SpoofFused.java new file mode 100644 index 0000000..3f0ec59 --- /dev/null +++ b/src/main/java/org/apache/sysml/lops/SpoofFused.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.lops; + +import java.util.ArrayList; + +import org.apache.sysml.lops.LopProperties.ExecLocation; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.compile.JobType; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; + +public class SpoofFused extends Lop +{ + private final Class<?> _class; + private final int _numThreads; + + public SpoofFused( ArrayList<Lop> inputs, DataType dt, ValueType vt, Class<?> cla, int k, ExecType etype) { + super(Type.SpoofFused, dt, vt); + _class = cla; + _numThreads = k; + + for( Lop lop : inputs ) { + addInput(lop); + lop.addOutput(this); + } + + lps.addCompatibility(JobType.INVALID); + lps.setProperties( inputs, etype, ExecLocation.ControlProgram, false, false, false ); + } + + @Override + public String toString() { + return "spoof("+_class.getSimpleName()+")"; + } + + @Override + public String getInstructions(String input1, String output) throws LopsException { + return getInstructions(new String[]{input1}, new String[]{output}); + } + + @Override + public String getInstructions(String input1, String input2, String output) throws LopsException { + return getInstructions(new String[]{input1, input2}, new String[]{output}); + } + + @Override + public String getInstructions(String input1, String input2, String input3, String output) throws LopsException { + return getInstructions(new String[]{input1, input2, input3}, new String[]{output}); + } + + @Override + public String getInstructions(String input1, String input2, String input3, String input4, String output) throws LopsException { + return getInstructions(new String[]{input1, input2, input3, input4}, new String[]{output}); + } + + @Override + public String getInstructions(String input1, String input2, String input3, String input4, String input5, String output) throws LopsException { + return getInstructions(new String[]{input1, input2, input3, input4, input5}, 
new String[]{output}); + } + + @Override + public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String output) throws LopsException { + return getInstructions(new String[]{input1, input2, input3, input4, input5, input6}, new String[]{output}); + } + + @Override + public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String input7, String output) throws LopsException { + return getInstructions(new String[]{input1, input2, input3, input4, input5, input6, input7}, new String[]{output}); + } + + @Override + public String getInstructions(String[] inputs, String output) throws LopsException { + return getInstructions(inputs, new String[]{output}); + } + + @Override + public String getInstructions(String[] inputs, String[] outputs) + throws LopsException + { + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + sb.append( OPERAND_DELIMITOR ); + sb.append( "spoof" ); + + sb.append( OPERAND_DELIMITOR ); + sb.append( _class.getName() ); + + for(int i=0; i < inputs.length; i++) { + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(i).prepInputOperand(inputs[i])); + } + + sb.append( OPERAND_DELIMITOR ); + sb.append( prepOutputOperand(outputs[0]) ); + + sb.append( OPERAND_DELIMITOR ); + sb.append( _numThreads ); + + return sb.toString(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/compile/Dag.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java index 898f4ec..b513951 100644 --- a/src/main/java/org/apache/sysml/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java @@ -1491,65 +1491,12 @@ public class Dag<N extends Lop> node.getInputs().get(6).getOutputParameters().getLabel(), 
node.getOutputParameters().getLabel()); } - else if (node.getInputs().size() == 13) { - // Used for im2col and reshape_col - inst_string = node.getInstructions( - node.getInputs().get(0).getOutputParameters().getLabel(), - node.getInputs().get(1).getOutputParameters().getLabel(), - node.getInputs().get(2).getOutputParameters().getLabel(), - node.getInputs().get(3).getOutputParameters().getLabel(), - node.getInputs().get(4).getOutputParameters().getLabel(), - node.getInputs().get(5).getOutputParameters().getLabel(), - node.getInputs().get(6).getOutputParameters().getLabel(), - node.getInputs().get(7).getOutputParameters().getLabel(), - node.getInputs().get(8).getOutputParameters().getLabel(), - node.getInputs().get(9).getOutputParameters().getLabel(), - node.getInputs().get(10).getOutputParameters().getLabel(), - node.getInputs().get(11).getOutputParameters().getLabel(), - node.getInputs().get(12).getOutputParameters().getLabel(), - node.getOutputParameters().getLabel()); - } - else if (node.getInputs().size() == 14) { - // Used for pooling_backward - inst_string = node.getInstructions( - node.getInputs().get(0).getOutputParameters().getLabel(), - node.getInputs().get(1).getOutputParameters().getLabel(), - node.getInputs().get(2).getOutputParameters().getLabel(), - node.getInputs().get(3).getOutputParameters().getLabel(), - node.getInputs().get(4).getOutputParameters().getLabel(), - node.getInputs().get(5).getOutputParameters().getLabel(), - node.getInputs().get(6).getOutputParameters().getLabel(), - node.getInputs().get(7).getOutputParameters().getLabel(), - node.getInputs().get(8).getOutputParameters().getLabel(), - node.getInputs().get(9).getOutputParameters().getLabel(), - node.getInputs().get(10).getOutputParameters().getLabel(), - node.getInputs().get(11).getOutputParameters().getLabel(), - node.getInputs().get(12).getOutputParameters().getLabel(), - node.getInputs().get(13).getOutputParameters().getLabel(), - node.getOutputParameters().getLabel()); - } - else 
if (node.getInputs().size() == 15) { - // Used for fused conv2d_bias_add - inst_string = node.getInstructions( - node.getInputs().get(0).getOutputParameters().getLabel(), - node.getInputs().get(1).getOutputParameters().getLabel(), - node.getInputs().get(2).getOutputParameters().getLabel(), - node.getInputs().get(3).getOutputParameters().getLabel(), - node.getInputs().get(4).getOutputParameters().getLabel(), - node.getInputs().get(5).getOutputParameters().getLabel(), - node.getInputs().get(6).getOutputParameters().getLabel(), - node.getInputs().get(7).getOutputParameters().getLabel(), - node.getInputs().get(8).getOutputParameters().getLabel(), - node.getInputs().get(9).getOutputParameters().getLabel(), - node.getInputs().get(10).getOutputParameters().getLabel(), - node.getInputs().get(11).getOutputParameters().getLabel(), - node.getInputs().get(12).getOutputParameters().getLabel(), - node.getInputs().get(13).getOutputParameters().getLabel(), - node.getInputs().get(14).getOutputParameters().getLabel(), - node.getOutputParameters().getLabel()); - } else { - throw new LopsException(node.printErrorLocation() + "Node with " + node.getInputs().size() + " inputs is not supported in CP yet! 
\n"); + String[] inputs = new String[node.getInputs().size()]; + for( int j=0; j<node.getInputs().size(); j++ ) + inputs[j] = node.getInputs().get(j).getOutputParameters().getLabel(); + inst_string = node.getInstructions(inputs, + node.getOutputParameters().getLabel()); } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java b/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java new file mode 100644 index 0000000..27263d3 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/**
 * Class loader that can define a class directly from an in-memory byte
 * array. If no bytes are given, it behaves like a regular
 * {@link URLClassLoader} over the given URLs.
 */
public class ByteClassLoader extends URLClassLoader
{
	private final byte[] _classBytes;

	/**
	 * @param urls URLs searched when no class bytes are given
	 * @param parent parent class loader (null for the bootstrap loader)
	 * @param classBytes bytes handed to {@code defineClass}, may be null
	 */
	public ByteClassLoader(URL[] urls, ClassLoader parent, byte[] classBytes) {
		super(urls, parent);
		_classBytes = classBytes;
	}

	/**
	 * Returns the class with the given name. With class bytes present, the
	 * class is defined from these bytes under the requested name; otherwise
	 * the name is resolved via the parent loader first and the URLs of this
	 * loader second.
	 *
	 * @param className binary name of the class
	 * @return the resolved class
	 * @throws ClassNotFoundException if no bytes are given and neither the
	 *         parent nor the URLs of this loader provide the class
	 */
	@Override
	public Class<?> findClass(String className) throws ClassNotFoundException {
		//reuse a class this loader already defined; defining the same name
		//twice would fail with a LinkageError
		Class<?> loaded = findLoadedClass(className);
		if (loaded != null)
			return loaded;

		if (_classBytes != null)
			return defineClass(className, _classBytes, 0, _classBytes.length);

		//note: never call loadClass from here; loadClass calls findClass
		//whenever the parent cannot resolve the name, which would recurse
		//until the stack overflows
		try {
			//keep the parent-first resolution order for direct callers
			return Class.forName(className, false, getParent());
		}
		catch (ClassNotFoundException ex) {
			return super.findClass(className);
		}
	}
}
+ */ + +package org.apache.sysml.runtime.codegen; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectOutputStream; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +import javax.tools.Diagnostic; +import javax.tools.Diagnostic.Kind; +import javax.tools.DiagnosticCollector; +import javax.tools.JavaCompiler; +import javax.tools.JavaCompiler.CompilationTask; +import javax.tools.JavaFileObject; +import javax.tools.StandardJavaFileManager; +import javax.tools.ToolProvider; + +import org.apache.commons.io.IOUtils; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.io.IOUtilFunctions; +import org.apache.sysml.runtime.util.LocalFileUtils; +import org.apache.sysml.utils.Statistics; + +public class CodegenUtils +{ + //cache to reuse compiled and loaded classes (this is also a workaround for classes, + //compiled during initial compilation and subsequently loaded as the working directory + //is cleaned up just before the actual execution + private static ConcurrentHashMap<String, Class<?>> _cache = new ConcurrentHashMap<String,Class<?>>(); + private static String _workingDir = null; + + public static Class<?> compileClass(String name, String src) + throws DMLRuntimeException + { + //reuse existing compiled class + Class<?> ret = _cache.get(name); + if( ret != null ) + return ret; + + long t0 = DMLScript.STATISTICS ? 
System.nanoTime() : 0; + + try + { + //create working dir on demand + if( _workingDir == null ) + createWorkingDir(); + + //write input file (for debugging / classpath handling) + File ftmp = new File(_workingDir+"/codegen/"+name+".java"); + if( !ftmp.getParentFile().exists() ) + ftmp.getParentFile().mkdirs(); + LocalFileUtils.writeTextFile(ftmp, src); + + //get system java compiler + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + if( compiler == null ) + throw new RuntimeException("Unable to obtain system java compiler."); + + //prepare file manager + DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<JavaFileObject>(); + StandardJavaFileManager fileManager = compiler.getStandardFileManager(diagnostics, null, null); + + //prepare input source code + Iterable<? extends JavaFileObject> sources = fileManager + .getJavaFileObjectsFromFiles(Arrays.asList(ftmp)); + + //prepare class path + URL runDir = CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation(); + String classpath = System.getProperty("java.class.path") + + File.pathSeparator + runDir.getPath(); + List<String> options = Arrays.asList("-classpath",classpath); + + //compile source code + CompilationTask task = compiler.getTask(null, fileManager, diagnostics, options, null, sources); + Boolean success = task.call(); + + //output diagnostics and error handling + for(Diagnostic<? 
extends JavaFileObject> tmp : diagnostics.getDiagnostics()) + if( tmp.getKind()==Kind.ERROR ) + System.err.println("ERROR: "+tmp.toString()); + if( success == null || !success ) + throw new RuntimeException("Failed to compile class "+name); + + //dynamically load compiled class + URLClassLoader classLoader = new URLClassLoader( + new URL[]{new File(_workingDir).toURI().toURL(), runDir}, + CodegenUtils.class.getClassLoader()); + ret = classLoader.loadClass("codegen."+name); + classLoader.close(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + + //keep compiled class for reuse + _cache.put(name, ret); + + if( DMLScript.STATISTICS ) { + Statistics.incrementCodegenClassCompile(); + Statistics.incrementCodegenClassCompileTime(System.nanoTime()-t0); + } + + return ret; + } + + public static Class<?> loadClass(String name, byte[] classBytes) throws DMLRuntimeException { + //reuse existing compiled class + Class<?> ret = _cache.get(name); + if( ret != null ) + return ret; + + //define class using the bytes + if(classBytes != null) + { + //ByteClassLoader byteLoader = new ByteClassLoader(classLoader.getURLs() , classLoader.getParent(), classBytes); + try { + ByteClassLoader byteLoader = new ByteClassLoader(new URL[]{} ,CodegenUtils.class.getClassLoader(), classBytes); + ret = byteLoader.findClass(name); + byteLoader.close(); + } catch (Exception e) { + throw new DMLRuntimeException(e); + } + } + else + { + //dynamically load compiled class + URL runDir = CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation(); + URLClassLoader classLoader = null; + try { + classLoader = new URLClassLoader( + new URL[]{new File(_workingDir).toURI().toURL(), runDir}, + CodegenUtils.class.getClassLoader()); + ret = classLoader.loadClass(name); + } + catch (Exception e) { + throw new DMLRuntimeException(e); + } + finally { + IOUtilFunctions.closeSilently(classLoader); + } + } + + //keep loaded class for reuse + _cache.put(name, ret); + return ret; + } + + 
public static Object createInstance(Class<?> cla) + throws DMLRuntimeException + { + Object ret = null; + + try { + ret = cla.newInstance(); + } + catch( Exception ex ) { + throw new DMLRuntimeException(ex); + } + + return ret; + } + + public static byte[] getClassAsByteArray(String name) + throws DMLRuntimeException + { + //reuse existing compiled class + Class<?> cls = _cache.get(name); + if( cls != null ) + return getClassAsByteArray(cls); + + + String classAsPath = name.replace('.', '/') + ".class"; + + URLClassLoader classLoader = null; + byte[] ret = null; + + try { + //dynamically load compiled class + URL runDir = CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation(); + classLoader = new URLClassLoader( + new URL[]{new File(_workingDir).toURI().toURL(), runDir}, + CodegenUtils.class.getClassLoader()); + InputStream stream = classLoader.getResourceAsStream(classAsPath); + ret = IOUtils.toByteArray(stream); + } + catch (IOException e) { + throw new DMLRuntimeException(e); + } + finally { + IOUtilFunctions.closeSilently(classLoader); + } + + return ret; + } + + + public static byte[] getClassAsByteArray(Class<?> cls) + throws DMLRuntimeException + { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try { + ObjectOutputStream oos = new ObjectOutputStream(bos); + oos.writeObject(cls); + oos.flush(); + return bos.toByteArray(); + } + catch( IOException e ) { + throw new DMLRuntimeException(e); + } + finally { + IOUtilFunctions.closeSilently(bos); + } + } + + private static void createWorkingDir() throws DMLRuntimeException { + if( _workingDir != null ) + return; + String tmp = LocalFileUtils.getWorkingDir(LocalFileUtils.CATEGORY_CODEGEN); + LocalFileUtils.createLocalFileIfNotExist(tmp); + _workingDir = tmp; + } + + public static URL[] getUrls() throws DMLRuntimeException { + try { + URL runDir = CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation(); + return new URL[]{new File(_workingDir).toURI().toURL(), runDir}; + } 
+ catch(Exception e) { + throw new DMLRuntimeException(e); + } + } + + public static String getSpoofType(Class<?> cls) { + if(cls.getSuperclass() == SpoofCellwise.class) + return "Cell" + cls.getName().split("\\.")[1]; + else if(cls.getSuperclass() == SpoofOuterProduct.class) + return "OP" + cls.getName().split("\\.")[1]; + else if(cls.getSuperclass() == SpoofRowAggregate.class) + return "RA" + cls.getName().split("\\.")[1]; + else + return "UNKNOWN"; + } +}
