[SYSTEMML-1285] New basic code generator for operator fusion This patch introduces a cleaned-up version of SPOOF's basic code generator, covering its core compiler and runtime operators as well as its basic integration into the stats and explain tools (SYSTEMML-1296 and SYSTEMML-1297).
Furthermore, this also includes the following minor fixes and improvements of existing components: * Fix of rewrite utils for creating binary scalar operations with boolean outputs * Cleanup instruction generation convolution lop * Fix lop dag compilation (removed constraint of max 7 input lops) * Improved value type handling of scalar comparison instructions * Fix various gpu-related src and javadoc warnings Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d7fd5879 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d7fd5879 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d7fd5879 Branch: refs/heads/master Commit: d7fd58795c06dea8db6fb55a045a8b312547f398 Parents: b78c125 Author: Matthias Boehm <[email protected]> Authored: Sun Feb 26 18:53:46 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Sun Feb 26 18:53:46 2017 -0800 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 5 + .../sysml/hops/codegen/SpoofCompiler.java | 407 ++++++++++++++ .../apache/sysml/hops/codegen/SpoofFusedOp.java | 212 ++++++++ .../apache/sysml/hops/codegen/cplan/CNode.java | 167 ++++++ .../sysml/hops/codegen/cplan/CNodeBinary.java | 260 +++++++++ .../sysml/hops/codegen/cplan/CNodeCell.java | 144 +++++ .../sysml/hops/codegen/cplan/CNodeData.java | 94 ++++ .../hops/codegen/cplan/CNodeOuterProduct.java | 165 ++++++ .../hops/codegen/cplan/CNodeRowAggVector.java | 111 ++++ .../sysml/hops/codegen/cplan/CNodeTpl.java | 201 +++++++ .../sysml/hops/codegen/cplan/CNodeUnary.java | 206 +++++++ .../sysml/hops/codegen/template/BaseTpl.java | 63 +++ .../sysml/hops/codegen/template/CellTpl.java | 289 ++++++++++ .../hops/codegen/template/CplanRegister.java | 168 ++++++ .../hops/codegen/template/OuterProductTpl.java | 489 +++++++++++++++++ .../sysml/hops/codegen/template/RowAggTpl.java | 321 
+++++++++++ .../hops/codegen/template/TemplateUtils.java | 313 +++++++++++ .../sysml/hops/rewrite/HopRewriteUtils.java | 10 +- .../apache/sysml/lops/ConvolutionTransform.java | 49 +- src/main/java/org/apache/sysml/lops/Lop.java | 107 ++-- .../java/org/apache/sysml/lops/SpoofFused.java | 119 ++++ .../java/org/apache/sysml/lops/compile/Dag.java | 63 +-- .../sysml/runtime/codegen/ByteClassLoader.java | 40 ++ .../sysml/runtime/codegen/CodegenUtils.java | 268 +++++++++ .../runtime/codegen/LibSpoofPrimitives.java | 257 +++++++++ .../sysml/runtime/codegen/SpoofCellwise.java | 430 +++++++++++++++ .../sysml/runtime/codegen/SpoofOperator.java | 74 +++ .../runtime/codegen/SpoofOuterProduct.java | 541 +++++++++++++++++++ .../runtime/codegen/SpoofRowAggregate.java | 188 +++++++ .../controlprogram/parfor/util/IDSequence.java | 21 +- .../cp/RelationalBinaryCPInstruction.java | 52 +- .../cp/ScalarScalarRelationalCPInstruction.java | 22 +- .../instructions/gpu/context/GPUContext.java | 2 + .../instructions/gpu/context/GPUObject.java | 1 + .../instructions/gpu/context/JCudaObject.java | 2 + .../runtime/matrix/data/LibMatrixMult.java | 18 +- .../sysml/runtime/util/LocalFileUtils.java | 24 + .../java/org/apache/sysml/utils/Statistics.java | 73 +++ 38 files changed, 5742 insertions(+), 234 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 3aa3dab..4021a1a 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -789,6 +789,11 @@ public abstract class Hop public ArrayList<Hop> getInput() { return _input; } + + public void addInput( Hop h ) { + _input.add(h); + h._parent.add(this); + } public long 
getRowsInBlock() { return _rows_in_block; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java new file mode 100644 index 0000000..dd24703 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.codegen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map.Entry; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysml.api.DMLException; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.hops.codegen.cplan.CNode; +import org.apache.sysml.hops.codegen.cplan.CNodeCell; +import org.apache.sysml.hops.codegen.cplan.CNodeData; +import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; +import org.apache.sysml.hops.codegen.cplan.CNodeTpl; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary; +import org.apache.sysml.hops.codegen.template.BaseTpl; +import org.apache.sysml.hops.codegen.template.CellTpl; +import org.apache.sysml.hops.codegen.template.CplanRegister; +import org.apache.sysml.hops.codegen.template.OuterProductTpl; +import org.apache.sysml.hops.codegen.template.RowAggTpl; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.ForStatement; +import org.apache.sysml.parser.ForStatementBlock; +import org.apache.sysml.parser.FunctionStatement; +import org.apache.sysml.parser.FunctionStatementBlock; +import org.apache.sysml.parser.IfStatement; +import org.apache.sysml.parser.IfStatementBlock; +import org.apache.sysml.parser.LanguageException; +import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.parser.WhileStatement; +import org.apache.sysml.parser.WhileStatementBlock; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.codegen.CodegenUtils; +import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; +import org.apache.sysml.runtime.matrix.data.Pair; +import 
org.apache.sysml.utils.Explain; +import org.apache.sysml.utils.Explain.ExplainType; +import org.apache.sysml.utils.Statistics; + +public class SpoofCompiler +{ + private static final Log LOG = LogFactory.getLog(SpoofCompiler.class.getName()); + + public static boolean OPTIMIZE = true; + + //internal configuration flags + public static final boolean LDEBUG = false; + public static final boolean SUM_PRODUCT = false; + public static final boolean RECOMPILE = true; + public static boolean USE_PLAN_CACHE = true; + public static boolean ALWAYS_COMPILE_LITERALS = false; + public static final boolean ALLOW_SPARK_OPS = false; + + //plan cache for cplan->compiled source to avoid unnecessary codegen/source code compile + //for equal operators from (1) different hop dags and (2) repeated recompilation + private static ConcurrentHashMap<CNode, Class<?>> planCache = new ConcurrentHashMap<CNode, Class<?>>(); + + public static void generateCode(DMLProgram dmlp) + throws LanguageException, HopsException, DMLRuntimeException + { + // cleanup static plan cache + planCache.clear(); + + // for each namespace, handle function statement blocks + for (String namespaceKey : dmlp.getNamespaces().keySet()) { + for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { + FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); + generateCodeFromStatementBlock(fsblock); + } + } + + // handle regular statement blocks in "main" method + for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { + StatementBlock current = dmlp.getStatementBlock(i); + generateCodeFromStatementBlock(current); + } + } + + public static void generateCodeFromStatementBlock(StatementBlock current) + throws HopsException, DMLRuntimeException + { + if (current instanceof FunctionStatementBlock) + { + FunctionStatementBlock fsb = (FunctionStatementBlock)current; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock sb : fstmt.getBody()) + 
generateCodeFromStatementBlock(sb); + } + else if (current instanceof WhileStatementBlock) + { + WhileStatementBlock wsb = (WhileStatementBlock) current; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + wsb.setPredicateHops(optimize(wsb.getPredicateHops(), true)); + for (StatementBlock sb : wstmt.getBody()) + generateCodeFromStatementBlock(sb); + } + else if (current instanceof IfStatementBlock) + { + IfStatementBlock isb = (IfStatementBlock) current; + IfStatement istmt = (IfStatement)isb.getStatement(0); + isb.setPredicateHops(optimize(isb.getPredicateHops(), true)); + for (StatementBlock sb : istmt.getIfBody()) + generateCodeFromStatementBlock(sb); + for (StatementBlock sb : istmt.getElseBody()) + generateCodeFromStatementBlock(sb); + } + else if (current instanceof ForStatementBlock) //incl parfor + { + ForStatementBlock fsb = (ForStatementBlock) current; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + fsb.setFromHops(optimize(fsb.getFromHops(), true)); + fsb.setToHops(optimize(fsb.getToHops(), true)); + fsb.setIncrementHops(optimize(fsb.getIncrementHops(), true)); + for (StatementBlock sb : fstmt.getBody()) + generateCodeFromStatementBlock(sb); + } + else //generic (last-level) + { + current.set_hops( generateCodeFromHopDAGs(current.get_hops()) ); + current.updateRecompilationFlag(); + } + } + + public static ArrayList<Hop> generateCodeFromHopDAGs(ArrayList<Hop> roots) + throws HopsException, DMLRuntimeException + { + if( roots == null ) + return roots; + + ArrayList<Hop> optimized = SpoofCompiler.optimize(roots, true); + Hop.resetVisitStatus(roots); + Hop.resetVisitStatus(optimized); + + return optimized; + } + + + /** + * Main interface of sum-product optimizer, predicate dag. 
+ * + * @param root dag root node + * @param compileLiterals if true literals compiled as constants, otherwise as scalar variables + * @return dag root node of modified dag + * @throws DMLRuntimeException if optimization failed + */ + public static Hop optimize( Hop root, boolean compileLiterals ) throws DMLRuntimeException { + if( root == null ) + return root; + + return optimize(new ArrayList<Hop>(Arrays.asList(root)), compileLiterals).get(0); + } + + /** + * Main interface of sum-product optimizer, statement block dag. + * + * @param roots dag root nodes + * @param compileLiterals if true literals compiled as constants, otherwise as scalar variables + * @return dag root nodes of modified dag + * @throws DMLRuntimeException if optimization failed + */ + @SuppressWarnings("unused") + public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean compileLiterals) + throws DMLRuntimeException + { + if( roots == null || roots.isEmpty() || !OPTIMIZE ) + return roots; + + long t0 = DMLScript.STATISTICS ? 
System.nanoTime() : 0; + ArrayList<Hop> ret = roots; + + try + { + //construct codegen plans + HashMap<Long, Pair<Hop[],CNodeTpl>> cplans = constructCPlans(roots, compileLiterals); + + //cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping, + //remove empty templates with single cnodedata input) + cplans = cleanupCPlans(cplans); + + //explain before modification + if( LDEBUG && cplans.size() > 0 ) { //existing cplans + LOG.info("Codegen EXPLAIN (before optimize): \n"+Explain.explainHops(roots)); + } + + //source code generation for all cplans + HashMap<Long, Pair<Hop[],Class<?>>> clas = new HashMap<Long, Pair<Hop[],Class<?>>>(); + for( Entry<Long, Pair<Hop[],CNodeTpl>> cplan : cplans.entrySet() ) { + Pair<Hop[],CNodeTpl> tmp = cplan.getValue(); + + if( !USE_PLAN_CACHE || !planCache.containsKey(tmp.getValue()) ) { + //generate java source code + String src = tmp.getValue().codegen(false); + + //explain debug output generated source code + if( LDEBUG || DMLScript.EXPLAIN != ExplainType.NONE ) { + LOG.info("Codegen EXPLAIN (generated code for HopID: " + cplan.getKey() +"):"); + LOG.info(src); + } + + //compile generated java source code + Class<?> cla = CodegenUtils.compileClass(tmp.getValue().getClassname(), src); + planCache.put(tmp.getValue(), cla); + } + else if( LDEBUG || DMLScript.STATISTICS ) { + Statistics.incrementCodegenPlanCacheHits(); + } + + Class<?> cla = planCache.get(tmp.getValue()); + if(cla != null) + clas.put(cplan.getKey(), new Pair<Hop[],Class<?>>(tmp.getKey(),cla)); + + if( LDEBUG || DMLScript.STATISTICS ) + Statistics.incrementCodegenPlanCacheTotal(); + } + + //generate final hop dag + ret = constructModifiedHopDag(roots, cplans, clas); + + //explain after modification + if( LDEBUG && cplans.size() > 0 ) { //existing cplans + LOG.info("Codegen EXPLAIN (after optimize): \n"+Explain.explainHops(roots)); + } + } + catch( Exception ex ) { + throw new DMLRuntimeException(ex); + } + + if( DMLScript.STATISTICS ) { + 
Statistics.incrementCodegenDAGCompile(); + Statistics.incrementCodegenCompileTime(System.nanoTime()-t0); + } + + return ret; + } + + + //////////////////// + // Codegen plan construction + + private static HashMap<Long, Pair<Hop[],CNodeTpl>> constructCPlans(ArrayList<Hop> roots, boolean compileLiterals) throws DMLException + { + LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> ret = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>(); + for( Hop hop : roots ) { + CplanRegister perRootCplans = new CplanRegister(); + HashSet<Long> memo = new HashSet<Long>(); + rConstructCPlans(hop, perRootCplans, memo, compileLiterals); + + for (Entry<Long, Pair<Hop[],CNodeTpl>> entry : perRootCplans.getTopLevelCplans().entrySet()) + if(!ret.containsKey(entry.getKey())) + ret.put(entry.getKey(), entry.getValue()); + } + return ret; + } + + private static void rConstructCPlans(Hop hop, CplanRegister cplanReg, HashSet<Long> memo, boolean compileLiterals) throws DMLException + { + if( memo.contains(hop.getHopID()) ) + return; + + //construct template instances + BaseTpl[] templates = new BaseTpl[]{ + new RowAggTpl(), new CellTpl(), new OuterProductTpl()}; + + //process hop with all templates + for( BaseTpl tpl : templates ) { + if( tpl.openTpl(hop) && tpl.findTplBoundaries(hop,cplanReg) ) { + cplanReg.insertCpplans(tpl.getType(), + tpl.constructTplCplan(compileLiterals)); + } + } + + //process childs recursively + memo.add(hop.getHopID()); + for( Hop c : hop.getInput() ) + rConstructCPlans(c, cplanReg, memo, compileLiterals); + } + + //////////////////// + // Codegen hop dag construction + + private static ArrayList<Hop> constructModifiedHopDag(ArrayList<Hop> orig, + HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, HashMap<Long, Pair<Hop[],Class<?>>> cla) + { + HashSet<Long> memo = new HashSet<Long>(); + for( int i=0; i<orig.size(); i++ ) { + Hop hop = orig.get(i); //w/o iterator because modified + rConstructModifiedHopDag(hop, cplans, cla, memo); + } + return orig; + } + + private static void 
rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, + HashMap<Long, Pair<Hop[],Class<?>>> clas, HashSet<Long> memo) + { + if( memo.contains(hop.getHopID()) ) + return; //already processed + + Hop hnew = hop; + if( clas.containsKey(hop.getHopID()) ) + { + //replace sub-dag with generated operator + Pair<Hop[], Class<?>> tmpCla = clas.get(hop.getHopID()); + CNodeTpl tmpCNode = cplans.get(hop.getHopID()).getValue(); + hnew = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(), + tmpCla.getValue(), false, tmpCNode.getOutputDimType()); + for( Hop in : tmpCla.getKey() ) { + hnew.addInput(in); //add inputs + } + hnew.setOutputBlocksizes(hop.getRowsInBlock() , hop.getColsInBlock()); + hnew.setDim1(hop.getDim1()); + hnew.setDim2(hop.getDim2()); + if(tmpCNode instanceof CNodeOuterProduct && ((CNodeOuterProduct)tmpCNode).isTransposeOutput() ) { + hnew = HopRewriteUtils.createTranspose(hnew); + } + + HopRewriteUtils.rewireAllParentChildReferences(hop, hnew); + memo.add(hnew.getHopID()); + } + + //process hops recursively (parent-child links modified) + for( int i=0; i<hnew.getInput().size(); i++ ) { + Hop c = hnew.getInput().get(i); + rConstructModifiedHopDag(c, cplans, clas, memo); + } + memo.add(hnew.getHopID()); + } + + /** + * Cleanup generated cplans in order to remove unnecessary inputs created + * during incremental construction. This is important as it avoids unnecessary + * redundant computation. 
+ * + * @param cplans set of cplans + */ + private static HashMap<Long, Pair<Hop[],CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) { + HashMap<Long, Pair<Hop[],CNodeTpl>> cplans2 = new HashMap<Long, Pair<Hop[],CNodeTpl>>(); + for( Entry<Long, Pair<Hop[],CNodeTpl>> e : cplans.entrySet() ) { + CNodeTpl tpl = e.getValue().getValue(); + Hop[] inHops = e.getValue().getKey(); + + //collect cplan leaf node names + HashSet<Long> leafs = new HashSet<Long>(); + rCollectLeafIDs(tpl.getOutput(), leafs); + + //create clean cplan w/ minimal inputs + if( inHops.length == leafs.size() ) + cplans2.put(e.getKey(), e.getValue()); + else { + tpl.cleanupInputs(leafs); + ArrayList<Hop> tmp = new ArrayList<Hop>(); + for( Hop hop : inHops ) + if( leafs.contains(hop.getHopID()) ) + tmp.add(hop); + cplans2.put(e.getKey(), new Pair<Hop[],CNodeTpl>( + tmp.toArray(new Hop[0]),tpl)); + } + + //remove cplan w/ single op and w/o agg + if( tpl instanceof CNodeCell && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG + && tpl.getOutput() instanceof CNodeUnary && tpl.getOutput().getInput().get(0) instanceof CNodeData) + cplans2.remove(e.getKey()); + + //remove cplan if empty + if( tpl.getOutput() instanceof CNodeData ) + cplans2.remove(e.getKey()); + } + + return cplans2; + } + + private static void rCollectLeafIDs(CNode node, HashSet<Long> leafs) { + //collect leaf variable names + if( node instanceof CNodeData ) + leafs.add(((CNodeData) node).getHopID()); + + //recursively process cplan + for( CNode c : node.getInput() ) + rCollectLeafIDs(c, leafs); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java new file mode 100644 index 0000000..357d41c --- /dev/null +++ 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen; + +import java.util.ArrayList; + +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.MultiThreadedHop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.MemoTable; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.Lop; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.LopsException; +import org.apache.sysml.lops.SpoofFused; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; + +public class SpoofFusedOp extends Hop implements MultiThreadedHop +{ + public enum SpoofOutputDimsType { + INPUT_DIMS, + ROW_DIMS, + COLUMN_DIMS_ROWS, + COLUMN_DIMS_COLS, + SCALAR, + ROW_RANK_DIMS, // right wdivmm + COLUMN_RANK_DIMS // left wdivmm + } + + private Class<?> _class = null; + private boolean _distSupported = false; + private int _numThreads = -1; + private SpoofOutputDimsType _dimsType; + + public SpoofFusedOp ( ) { + + } + + public SpoofFusedOp( String name, DataType dt, ValueType vt, Class<?> cla, 
boolean dist, SpoofOutputDimsType type ) { + super(name, dt, vt); + _class = cla; + _distSupported = dist; + _dimsType = type; + } + + @Override + public void setMaxNumThreads(int k) { + _numThreads = k; + } + + @Override + public int getMaxNumThreads() { + return _numThreads; + } + + @Override + public boolean allowsAllExecTypes() { + return _distSupported; + } + + @Override + protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { + return OptimizerUtils.estimateSize(dim1, dim2); + } + + @Override + protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + @Override + protected long[] inferOutputCharacteristics(MemoTable memo) { + return null; + } + + @Override + public Lop constructLops() throws HopsException, LopsException { + if( getLops() != null ) + return getLops(); + + ExecType et = optFindExecType(); + + ArrayList<Lop> inputs = new ArrayList<Lop>(); + for( Hop c : getInput() ) + inputs.add(c.constructLops()); + + int k = OptimizerUtils.getConstrainedNumThreads(_numThreads); + SpoofFused lop = new SpoofFused(inputs, getDataType(), getValueType(), _class, k, et); + setOutputDimensions(lop); + setLineNumbers(lop); + setLops(lop); + + return lop; + } + + @Override + protected ExecType optFindExecType() throws HopsException { + + checkAndSetForcedPlatform(); + + if( _etypeForced != null ) { + _etype = _etypeForced; + } + else { + _etype = findExecTypeByMemEstimate(); + checkAndSetInvalidCPDimsAndSize(); + } + + //ensure valid execution plans + if( _etype == ExecType.MR ) + _etype = ExecType.CP; + + return _etype; + } + + @Override + public String getOpString() { + return "spoof("+_class.getSimpleName()+")"; + } + + @Override + public void refreshSizeInformation() { + switch(_dimsType) + { + case ROW_DIMS: + setDim1(getInput().get(0).getDim1()); + setDim2(1); + break; + case COLUMN_DIMS_ROWS: + setDim1(getInput().get(0).getDim2()); + setDim2(1); + break; + case COLUMN_DIMS_COLS: + setDim1(1); + 
setDim2(getInput().get(0).getDim2()); + break; + case INPUT_DIMS: + setDim1(getInput().get(0).getDim1()); + setDim2(getInput().get(0).getDim2()); + break; + case SCALAR: + setDim1(0); + setDim2(0); + break; + case ROW_RANK_DIMS: + setDim1(getInput().get(0).getDim1()); + setDim2(getInput().get(1).getDim2()); + break; + case COLUMN_RANK_DIMS: + setDim1(getInput().get(0).getDim2()); + setDim2(getInput().get(1).getDim2()); + break; + default: + throw new RuntimeException("Failed to refresh size information " + + "for type: "+_dimsType.toString()); + } + } + + @Override + public Object clone() throws CloneNotSupportedException + { + SpoofFusedOp ret = new SpoofFusedOp(); + + //copy generic attributes + ret.clone(this, false); + + //copy specific attributes + ret._class = _class; + ret._distSupported = _distSupported; + ret._numThreads = _numThreads; + ret._dimsType = _dimsType; + return ret; + } + + @Override + public boolean compare( Hop that ) + { + if( !(that instanceof SpoofFusedOp) ) + return false; + + SpoofFusedOp that2 = (SpoofFusedOp)that; + boolean ret = ( _class.equals(that2._class) + && _distSupported == that2._distSupported + && _numThreads == that2._numThreads + && getInput().size() == that2.getInput().size()); + + if( ret ) { + for( int i=0; i<getInput().size(); i++ ) + ret &= (getInput().get(i) == that2.getInput().get(i)); + } + + return ret; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java new file mode 100644 index 0000000..46637cc --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; + +public abstract class CNode +{ + private static final IDSequence _seq = new IDSequence(); + + protected ArrayList<CNode> _inputs = null; + protected CNode _output = null; + protected boolean _generated = false; + protected String _genVar = null; + protected long _rows = -1; + protected long _cols = -1; + protected DataType _dataType; + protected boolean _literal = false; + + //cached hash to allow memoization in DAG structures and repeated + //recursive hash computation over all inputs (w/ reset on updates) + protected int _hash = 0; + + public CNode() { + _inputs = new ArrayList<CNode>(); + _generated = false; + } + + public ArrayList<CNode> getInput() { + return _inputs; + } + + public String createVarname() { + _genVar = "TMP"+_seq.getNextID(); + return _genVar; + } + + protected String getCurrentVarName() { + return "TMP"+(_seq.getCurrentID()-1); + } + + public String getVarname() { + return _genVar; + } + + public String getClassname() { + return getVarname(); + } + + public void resetGenerated() { + if( _generated ) + for( CNode cn : 
_inputs ) + cn.resetGenerated(); + _generated = false; + } + + public void setNumRows(long rows) { + _rows = rows; + } + + public long getNumRows() { + return _rows; + } + + public void setNumCols(long cols) { + _cols = cols; + } + + public long getNumCols() { + return _cols; + } + + public DataType getDataType() { + return _dataType; + } + + public void setDataType(DataType dt) { + _dataType = dt; + _hash = 0; + } + + public boolean isLiteral() { + return _literal; + } + + public void setLiteral(boolean literal) { + _literal = literal; + _hash = 0; + } + + public CNode getOutput() { + return _output; + } + + public void setOutput(CNode output) { + _output = output; + _hash = 0; + } + + public abstract String codegen(boolean sparse) ; + + public abstract void setOutputDims(); + + /////////////////////////////////////// + // Functionality for plan cache + + //note: genvar/generated changed on codegen and not considered, + //rows and cols also not include to increase reuse potential + + @Override + public int hashCode() { + if( _hash == 0 ) { + int numIn = _inputs.size(); + int[] tmp = new int[numIn + 3]; + //include inputs, partitioned by matrices and scalars to increase + //reuse in case of interleaved inputs (see CNodeTpl.renameInputs) + int pos = 0; + for( CNode c : _inputs ) + if( c.getDataType()==DataType.MATRIX ) + tmp[pos++] = c.hashCode(); + for( CNode c : _inputs ) + if( c.getDataType()!=DataType.MATRIX ) + tmp[pos++] = c.hashCode(); + tmp[numIn+0] = (_output!=null)?_output.hashCode():0; + tmp[numIn+1] = (_dataType!=null)?_dataType.hashCode():0; + tmp[numIn+2] = Boolean.hashCode(_literal); + _hash = Arrays.hashCode(tmp); + } + return _hash; + } + + @Override + public boolean equals(Object that) { + if( !(that instanceof CNode) ) + return false; + + CNode cthat = (CNode) that; + boolean ret = _inputs.size() == cthat._inputs.size(); + for( int i=0; i<_inputs.size() && ret; i++ ) + ret &= _inputs.get(i).equals(_inputs.get(i)); + return ret + && (_output == 
cthat._output || _output.equals(cthat._output)) + && _dataType == cthat._dataType + && _literal == cthat._literal; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java new file mode 100644 index 0000000..1bfaab4 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.Arrays; + +import org.apache.sysml.parser.Expression.DataType; + + +public class CNodeBinary extends CNode +{ + public enum BinType { + DOT_PRODUCT, + VECT_MULT_ADD, VECT_DIV_ADD, + VECT_MULT_SCALAR, VECT_DIV_SCALAR, + MULT, DIV, PLUS, MINUS, MODULUS, INTDIV, + LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL, + MIN, MAX, AND, OR, LOG, POW, + MINUS1_MULT; + + public static boolean contains(String value) { + for( BinType bt : values() ) + if( bt.toString().equals(value) ) + return true; + return false; + } + + public boolean isCommutative() { + return ( this == EQUAL || this == NOTEQUAL + || this == PLUS || this == MULT + || this == MIN || this == MAX ); + } + + public String getTemplate(boolean sparse) { + switch (this) { + case DOT_PRODUCT: + return sparse ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, %LEN%);\n" : + " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; + + case VECT_MULT_ADD: + return sparse ? " LibSpoofPrimitives.vectMultiplyAdd(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, %LEN%);\n" : + " LibSpoofPrimitives.vectMultiplyAdd(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n"; + + case VECT_DIV_ADD: + return sparse ? " LibSpoofPrimitives.vectDivAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, %LEN%);\n" : + " LibSpoofPrimitives.vectDivAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n"; + + case VECT_DIV_SCALAR: + return sparse ? 
" LibSpoofPrimitives.vectDivWrite(%IN1v%, %IN1i%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n" : + " LibSpoofPrimitives.vectDivWrite(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n"; + + case VECT_MULT_SCALAR: + return " LibSpoofPrimitives.vectMultiplyWrite(%IN2%, %IN1%, %POS1%, %OUT%, 0, %LEN%);\n"; + + + /*Can be replaced by function objects*/ + case MULT: + return " double %TMP% = %IN1% * %IN2%;\n" ; + + case DIV: + return " double %TMP% = %IN1% / %IN2%;\n" ; + case PLUS: + return " double %TMP% = %IN1% + %IN2%;\n" ; + case MINUS: + return " double %TMP% = %IN1% - %IN2%;\n" ; + case MODULUS: + return " double %TMP% = %IN1% % %IN2%;\n" ; + case INTDIV: + return " double %TMP% = (int) %IN1% / %IN2%;\n" ; + case LESS: + return " double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n" ; + case LESSEQUAL: + return " double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n" ; + case GREATER: + return " double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n" ; + case GREATEREQUAL: + return " double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n" ; + case EQUAL: + return " double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n" ; + case NOTEQUAL: + return " double %TMP% = (%IN1% != %IN2%) ? 
1 : 0;\n" ; + + case MIN: + return " double %TMP% = Math.min(%IN1%, %IN2%);\n" ; + case MAX: + return " double %TMP% = Math.max(%IN1%, %IN2%);\n" ; + case LOG: + return " double %TMP% = Math.log(%IN1%)/Math.log(%IN2%);\n" ; + case POW: + return " double %TMP% = Math.pow(%IN1%, %IN2%);\n" ; + case MINUS1_MULT: + return " double %TMP% = 1 - %IN1% * %IN2%;\n" ; + + default: + throw new RuntimeException("Invalid binary type: "+this.toString()); + } + } + } + + private final BinType _type; + + public CNodeBinary( CNode in1, CNode in2, BinType type ) { + //canonicalize commutative matrix-scalar operations + //to increase reuse potential + if( type.isCommutative() && in1 instanceof CNodeData + && in1.getDataType()==DataType.SCALAR ) { + CNode tmp = in1; + in1 = in2; + in2 = tmp; + } + + _inputs.add(in1); + _inputs.add(in2); + _type = type; + setOutputDims(); + } + + public BinType getType() { + return _type; + } + + @Override + public String codegen(boolean sparse) { + if( _generated ) + return ""; + + StringBuilder sb = new StringBuilder(); + + //generate children + sb.append(_inputs.get(0).codegen(sparse)); + sb.append(_inputs.get(1).codegen(sparse)); + + //generate binary operation + String var = createVarname(); + String tmp = _type.getTemplate(sparse); + tmp = tmp.replaceAll("%TMP%", var); + for( int j=1; j<=2; j++ ) { + String varj = _inputs.get(j-1).getVarname(); + if( sparse && !tmp.contains("%IN"+j+"%") ) { + tmp = tmp.replaceAll("%IN"+j+"v%", varj+"vals"); + tmp = tmp.replaceAll("%IN"+j+"i%", varj+"ix"); + } + else + tmp = tmp.replaceAll("%IN"+j+"%", varj ); + + if(varj.startsWith("_b") ) //i.e. 
b.get(index) + tmp = tmp.replaceAll("%POS"+j+"%", "_bi"); + else + tmp = tmp.replaceAll("%POS"+j+"%", varj+"i"); + } + sb.append(tmp); + + //mark as generated + _generated = true; + + return sb.toString(); + } + + @Override + public String toString() { + switch(_type) { + case DOT_PRODUCT: return "b(dot)"; + case VECT_MULT_ADD: return "b(vma)"; + case VECT_DIV_ADD: return "b(vda)"; + case MULT: return "b(*)"; + case DIV: return "b(/)"; + case VECT_DIV_SCALAR: return "b(vector/)"; + case VECT_MULT_SCALAR: return "b(vector*)"; + default: + return super.toString(); + } + } + + public void setOutputDims() + { + switch(_type) { + //VECT + case VECT_MULT_ADD: + case VECT_DIV_ADD: + _rows = _inputs.get(1)._rows; + _cols = _inputs.get(1)._cols; + _dataType= DataType.MATRIX; + break; + + case VECT_DIV_SCALAR: + case VECT_MULT_SCALAR: + _rows = _inputs.get(0)._rows; + _cols = _inputs.get(0)._cols; + _dataType= DataType.MATRIX; + break; + + + case DOT_PRODUCT: + + //SCALAR Arithmetic + case MULT: + case DIV: + case PLUS: + case MINUS: + case MINUS1_MULT: + case MODULUS: + case INTDIV: + //SCALAR Comparison + case LESS: + case LESSEQUAL: + case GREATER: + case GREATEREQUAL: + case EQUAL: + case NOTEQUAL: + //SCALAR LOGIC + case MIN: + case MAX: + case AND: + case OR: + case LOG: + case POW: + _rows = 0; + _cols = 0; + _dataType= DataType.SCALAR; + break; + } + } + + @Override + public int hashCode() { + if( _hash == 0 ) { + int h1 = super.hashCode(); + int h2 = _type.hashCode(); + _hash = Arrays.hashCode(new int[]{h1,h2}); + } + return _hash; + } + + @Override + public boolean equals(Object o) { + if( !(o instanceof CNodeBinary) ) + return false; + + CNodeBinary that = (CNodeBinary) o; + return super.equals(that) + && _type == that._type; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java new file mode 100644 index 0000000..a9408ca --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; +import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; + +public class CNodeCell extends CNodeTpl +{ + private static final String TEMPLATE = + "package codegen;\n" + + "import java.util.Arrays;\n" + + "import java.io.Serializable;\n" + + "import java.util.ArrayList;\n" + + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysml.runtime.codegen.SpoofCellwise;\n" + + "import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;\n" + + "import org.apache.commons.math3.util.FastMath;\n" + + "\n" + + "public final class %TMP% extends SpoofCellwise {\n" + + " public %TMP%() {\n" + + " _type = CellType.%TYPE%;\n" + + " }\n" + + " protected double genexecDense( double _a, double[][] _b, double[] _scalars, int _n, int _m, int _rowIndex, int _colIndex) { \n" + + "%BODY_dense%" + + " return %OUT%;\n" + + " } \n" + + "}"; + + private CellType _type = null; + private boolean _multipleConsumers = false; + + public CNodeCell(ArrayList<CNode> inputs, CNode output ) { + super(inputs,output); + } + + public void setMultipleConsumers(boolean flag) { + _multipleConsumers = flag; + } + + public boolean hasMultipleConsumers() { + return _multipleConsumers; + } + + public void setCellType(CellType type) { + _type = type; + _hash = 0; + } + + public CellType getCellType() { + return _type; + } + + @Override + public String codegen(boolean sparse) { + String tmp = TEMPLATE; + + //rename inputs + rReplaceDataNode(_output, _inputs.get(0), "_a"); + renameInputs(_inputs, 1); + + //generate dense/sparse bodies + String tmpDense = _output.codegen(false); + _output.resetGenerated(); + + tmp = tmp.replaceAll("%TMP%", createVarname()); + tmp = tmp.replaceAll("%BODY_dense%", tmpDense); + + //return last TMP + tmp = tmp.replaceAll("%OUT%", getCurrentVarName()); + + 
//replace aggregate information + tmp = tmp.replaceAll("%TYPE%", getCellType().toString()); + + return tmp; + } + + @Override + public void setOutputDims() { + + + } + + @Override + public CNodeTpl clone() { + CNodeCell tmp = new CNodeCell(_inputs, _output); + tmp.setDataType(getDataType()); + tmp.setCellType(getCellType()); + tmp.setMultipleConsumers(hasMultipleConsumers()); + return tmp; + } + + @Override + public SpoofOutputDimsType getOutputDimType() { + switch( _type ) { + case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS; + case ROW_AGG: return SpoofOutputDimsType.ROW_DIMS; + case FULL_AGG: return SpoofOutputDimsType.SCALAR; + default: + throw new RuntimeException("Unsupported cell type: "+_type.toString()); + } + } + + @Override + public int hashCode() { + if( _hash == 0 ) { + int h1 = super.hashCode(); + int h2 = _type.hashCode(); + //note: _multipleConsumers irrelevant for plan comparison + _hash = Arrays.hashCode(new int[]{h1,h2}); + } + return _hash; + } + + @Override + public boolean equals(Object o) { + if(!(o instanceof CNodeCell)) + return false; + + CNodeCell that = (CNodeCell)o; + return super.equals(that) + && _type == that._type; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java new file mode 100644 index 0000000..d5457e8 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.Arrays; + +import org.apache.sysml.hops.Hop; +import org.apache.sysml.parser.Expression.DataType; + +public class CNodeData extends CNode +{ + protected final String _name; + protected final long _hopID; + + public CNodeData(Hop hop) { + this(hop, hop.getDim1(), hop.getDim2(), hop.getDataType()); + } + + public CNodeData(Hop hop, long rows, long cols, DataType dt) { + //note: previous rewrites might have created hops with equal name + //hence, we also keep the hopID to uniquely identify inputs + _name = hop.getName(); + _hopID = hop.getHopID(); + _rows = rows; + _cols = cols; + _dataType = dt; + } + + public CNodeData(CNodeData node, String newName) { + _name = newName; + _hopID = node.getHopID(); + _rows = node.getNumRows(); + _cols = node.getNumCols(); + _dataType = node.getDataType(); + } + + @Override + public String getVarname() { + return _name; + } + + public long getHopID() { + return _hopID; + } + + @Override + public String codegen(boolean sparse) { + return ""; + } + + @Override + public void setOutputDims() { + + } + + @Override + public String toString() { + return "CdataNode[name="+_name+", id="+_hopID+"]"; + } + + @Override + public int hashCode() { + if( _hash == 0 ) { + int h1 = super.hashCode(); + int h2 = isLiteral() ? 
_name.hashCode() : 0; + _hash = Arrays.hashCode(new int[]{h1,h2}); + } + return _hash; + } + + @Override + public boolean equals(Object o) { + return (o instanceof CNodeData + && super.equals(o) + && (!isLiteral() || _name.equals(((CNodeData)o)._name))); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java new file mode 100644 index 0000000..8c2e38c --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; +import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; + + +public class CNodeOuterProduct extends CNodeTpl +{ + private static final String TEMPLATE = + "package codegen;\n" + + "import java.util.Arrays;\n" + + "import java.util.ArrayList;\n" + + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysml.runtime.codegen.SpoofOuterProduct;\n" + + "import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;\n" + + "import org.apache.commons.math3.util.FastMath;\n" + + "\n" + + "public final class %TMP% extends SpoofOuterProduct { \n" + + " public %TMP%() {\n" + + " _outerProductType = OutProdType.%TYPE%;\n" + + " }\n" + + " protected void genexecDense( double _a, double[] _a1, int _a1i, double[] _a2, int _a2i, double[][] _b, double[] _scalars, double[] _c, int _ci, int _n, int _m, int _k, int _rowIndex, int _colIndex) { \n" + + "%BODY_dense%" + + " } \n" + + " protected double genexecCellwise( double _a, double[] _a1, int _a1i, double[] _a2, int _a2i, double[][] _b, double[] _scalars, int _n, int _m, int _k, int _rowIndex, int _colIndex) { \n" + + "%BODY_cellwise%" + + " return %OUT_cellwise%;\n" + + " } \n" + + + "}"; + + private OutProdType _type = null; + private boolean _transposeOutput = false; + + public CNodeOuterProduct(ArrayList<CNode> inputs, CNode output ) { + super(inputs,output); + } + + @Override + public String codegen(boolean sparse) { + // note: ignore sparse flag, generate both + String tmp = TEMPLATE; + + //rename inputs + rReplaceDataNode(_output, _inputs.get(0), "_a"); + rReplaceDataNode(_output, _inputs.get(1), "_a1"); // u + rReplaceDataNode(_output, _inputs.get(2), "_a2"); // v + renameInputs(_inputs, 3); + + //generate dense/sparse bodies + String tmpDense = _output.codegen(false); + 
_output.resetGenerated(); + + tmp = tmp.replaceAll("%TMP%", createVarname()); + + if(_type == OutProdType.LEFT_OUTER_PRODUCT || _type == OutProdType.RIGHT_OUTER_PRODUCT) { + tmp = tmp.replaceAll("%BODY_dense%", tmpDense); + tmp = tmp.replaceAll("%OUT%", "_c"); + tmp = tmp.replaceAll("%BODY_cellwise%", ""); + tmp = tmp.replaceAll("%OUT_cellwise%", "0"); + } + else { + tmp = tmp.replaceAll("%BODY_dense%", ""); + tmp = tmp.replaceAll("%BODY_cellwise%", tmpDense); + tmp = tmp.replaceAll("%OUT_cellwise%", getCurrentVarName()); + } + //replace size information + tmp = tmp.replaceAll("%LEN%", "_k"); + + tmp = tmp.replaceAll("%POSOUT%", "_ci"); + + tmp = tmp.replaceAll("%TYPE%", _type.toString()); + + return tmp; + } + + public void setOutProdType(OutProdType type) { + _type = type; + _hash = 0; + } + + public OutProdType getOutProdType() { + return _type; + } + + @Override + public void setOutputDims() { + + } + + public void setTransposeOutput(boolean transposeOutput) { + _transposeOutput = transposeOutput; + _hash = 0; + } + + + public boolean isTransposeOutput() { + return _transposeOutput; + } + + @Override + public SpoofOutputDimsType getOutputDimType() { + switch( _type ) { + case LEFT_OUTER_PRODUCT: + return SpoofOutputDimsType.COLUMN_RANK_DIMS; + case RIGHT_OUTER_PRODUCT: + return SpoofOutputDimsType.ROW_RANK_DIMS; + case CELLWISE_OUTER_PRODUCT: + return SpoofOutputDimsType.INPUT_DIMS; + case AGG_OUTER_PRODUCT: + return SpoofOutputDimsType.SCALAR; + default: + throw new RuntimeException("Unsupported outer product type: "+_type.toString()); + } + } + + @Override + public CNodeTpl clone() { + return new CNodeOuterProduct(_inputs, _output); + } + + @Override + public int hashCode() { + if( _hash == 0 ) { + int h1 = super.hashCode(); + int h2 = _type.hashCode(); + int h3 = Boolean.hashCode(_transposeOutput); + _hash = Arrays.hashCode(new int[]{h1,h2,h3}); + } + return _hash; + } + + @Override + public boolean equals(Object o) { + if(!(o instanceof CNodeOuterProduct)) 
+ return false; + + CNodeOuterProduct that = (CNodeOuterProduct)o; + return super.equals(that) + && _type == that._type + && _transposeOutput == that._transposeOutput; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java new file mode 100644 index 0000000..147615f --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.ArrayList; + +import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; + +public class CNodeRowAggVector extends CNodeTpl +{ + private static final String TEMPLATE = + "package codegen;\n" + + "import java.util.Arrays;\n" + + "import java.util.ArrayList;\n" + + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysml.runtime.codegen.SpoofRowAggregate;\n" + + "\n" + + "public final class %TMP% extends SpoofRowAggregate { \n" + + " public %TMP%() {\n" + + " _colVector = %FLAG%;\n" + + " }\n" + + " protected void genexecRowDense( double[] _a, int _ai, double[][] _b, double[] _scalars, double[] _c, int _len, int _rowIndex ) { \n" + + "%BODY_dense%" + + " } \n" + + " protected void genexecRowSparse( double[] _avals, int[] _aix, int _ai, double[][] _b, double[] _scalars, double[] _c, int _len, int _rowIndex ) { \n" + + "%BODY_sparse%" + + " } \n" + + "}\n"; + + public CNodeRowAggVector(ArrayList<CNode> inputs, CNode output ) { + super(inputs, output); + } + + + @Override + public String codegen(boolean sparse) { + // note: ignore sparse flag, generate both + String tmp = TEMPLATE; + + //rename inputs + rReplaceDataNode(_output, _inputs.get(0), "_a"); // input matrix + renameInputs(_inputs, 1); + + //generate dense/sparse bodies + String tmpDense = _output.codegen(false); + _output.resetGenerated(); + String tmpSparse = _output.codegen(true); + tmp = tmp.replaceAll("%TMP%", createVarname()); + tmp = tmp.replaceAll("%BODY_dense%", tmpDense); + tmp = tmp.replaceAll("%BODY_sparse%", tmpSparse); + + //replace outputs + tmp = tmp.replaceAll("%OUT%", "_c"); + tmp = tmp.replaceAll("%POSOUT%", "0"); + + //replace size information + tmp = tmp.replaceAll("%LEN%", "_len"); + + //replace colvector information and start position + tmp = tmp.replaceAll("%FLAG%", String.valueOf(_output._cols==1)); + tmp = tmp.replaceAll("_bi", "0"); + + return tmp; + } + + @Override 
+ public void setOutputDims() { + // TODO Auto-generated method stub + + } + + @Override + public SpoofOutputDimsType getOutputDimType() { + return (_output._cols==1) ? + SpoofOutputDimsType.COLUMN_DIMS_ROWS : //column vector + SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector + } + + @Override + public CNodeTpl clone() { + return new CNodeRowAggVector(_inputs, _output); + } + + @Override + public int hashCode() { + return super.hashCode(); + } + + @Override + public boolean equals(Object o) { + return (o instanceof CNodeRowAggVector + && super.equals(o)); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java new file mode 100644 index 0000000..719770b --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; + +import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; +import org.apache.sysml.parser.Expression.DataType; + +public abstract class CNodeTpl extends CNode implements Cloneable +{ + public CNodeTpl(ArrayList<CNode> inputs, CNode output ) { + if(inputs.size() < 1) + throw new RuntimeException("Cannot pass empty inputs to the CNodeTpl"); + + for(CNode input : inputs) + addInput(input); + _output = output; + } + + public void addInput(CNode in) { + //check for duplicate entries or literals + if( containsInput(in) || in.isLiteral() ) + return; + + _inputs.add(in); + } + + public void cleanupInputs(HashSet<Long> filter) { + ArrayList<CNode> tmp = new ArrayList<CNode>(); + for( CNode in : _inputs ) + if( in instanceof CNodeData && filter.contains(((CNodeData) in).getHopID()) ) + tmp.add(in); + _inputs = tmp; + } + + public String codegen() { + return codegen(false); + } + + public abstract CNodeTpl clone(); + + public abstract SpoofOutputDimsType getOutputDimType(); + + protected void renameInputs(ArrayList<CNode> inputs, int startIndex) { + //create map of hopID to data nodes with new names, used for CSE + HashMap<Long, CNode> nodes = new HashMap<Long, CNode>(); + for(int i=startIndex, sPos=0, mPos=0; i < inputs.size(); i++) { + CNode cnode = inputs.get(i); + if( !(cnode instanceof CNodeData) || ((CNodeData)cnode).isLiteral()) + continue; + CNodeData cdata = (CNodeData)cnode; + if( cdata.getDataType() == DataType.SCALAR || ( cdata.getNumCols() == 0 && cdata.getNumRows() == 0) ) + nodes.put(cdata.getHopID(), new CNodeData(cdata, "_scalars["+ mPos++ +"]")); + else + nodes.put(cdata.getHopID(), new CNodeData(cdata, "_b["+ sPos++ +"]")); + } + + //single pass to replace all names + rReplaceDataNode(_output, nodes, new HashMap<Long, CNode>()); + } + + 
protected void rReplaceDataNode( CNode root, CNode input, String newName ) { + if( !(input instanceof CNodeData) ) + return; + + //create temporary name mapping + HashMap<Long, CNode> names = new HashMap<Long, CNode>(); + CNodeData tmp = (CNodeData)input; + names.put(tmp.getHopID(), new CNodeData(tmp, newName)); + + rReplaceDataNode(root, names, new HashMap<Long,CNode>()); + } + + /** + * Recursively searches for data nodes and replaces them if found. + * + * @param node current node in recursive descend + * @param dnodes prepared data nodes, identified by own hop id + * @param lnodes memoized lookup nodes, identified by data node hop id + */ + protected void rReplaceDataNode( CNode node, HashMap<Long, CNode> dnodes, HashMap<Long, CNode> lnodes ) + { + for( int i=0; i<node._inputs.size(); i++ ) { + //recursively process children + rReplaceDataNode(node._inputs.get(i), dnodes, lnodes); + + //replace leaf data node + if( node._inputs.get(i) instanceof CNodeData ) { + CNodeData tmp = (CNodeData)node._inputs.get(i); + if( dnodes.containsKey(tmp.getHopID()) ) + node._inputs.set(i, dnodes.get(tmp.getHopID())); + } + + //replace lookup on top of leaf data node + if( node._inputs.get(i) instanceof CNodeUnary + && ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP) { + CNodeData tmp = (CNodeData)node._inputs.get(i)._inputs.get(0); + if( !lnodes.containsKey(tmp.getHopID()) ) + lnodes.put(tmp.getHopID(), node._inputs.get(i)); + else + node._inputs.set(i, lnodes.get(tmp.getHopID())); + } + } + } + + public void rReplaceDataNode( CNode node, long hopID, CNode newNode ) + { + for( int i=0; i<node._inputs.size(); i++ ) { + //replace leaf node + if( node._inputs.get(i) instanceof CNodeData ) { + CNodeData tmp = (CNodeData)node._inputs.get(i); + if( tmp.getHopID() == hopID ) + node._inputs.set(i, newNode); + } + //recursively process children + rReplaceDataNode(node._inputs.get(i), hopID, newNode); + + //remove unnecessary lookups + if( node._inputs.get(i) instanceof 
CNodeUnary + && ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP + && node._inputs.get(i)._inputs.get(0).getDataType()==DataType.SCALAR) + node._inputs.set(i, node._inputs.get(i)._inputs.get(0)); + } + } + + public void rInsertLookupNode( CNode node, long hopID, HashMap<Long, CNode> memo ) + { + for( int i=0; i<node._inputs.size(); i++ ) { + //recursively process children + rInsertLookupNode(node._inputs.get(i), hopID, memo); + + //replace leaf node + if( node._inputs.get(i) instanceof CNodeData ) { + CNodeData tmp = (CNodeData)node._inputs.get(i); + if( tmp.getHopID() == hopID ) { + //use memo structure to retain DAG structure + CNode lookup = memo.get(hopID); + if( lookup == null ) { + lookup = new CNodeUnary(tmp, UnaryType.LOOKUP); + memo.put(hopID, lookup); + } + node._inputs.set(i, lookup); + } + } + } + } + + /** + * Checks for duplicates (object ref or varname). + * + * @param input new input node + * @return true if duplicate, false otherwise + */ + private boolean containsInput(CNode input) { + if( !(input instanceof CNodeData) ) + return false; + + CNodeData input2 = (CNodeData)input; + for( CNode cnode : _inputs ) { + if( !(cnode instanceof CNodeData) ) + continue; + CNodeData cnode2 = (CNodeData)cnode; + if( cnode2._name.equals(input2._name) && cnode2._hopID==input2._hopID ) + return true; + } + + return false; + } + + @Override + public int hashCode() { + return super.hashCode(); + } + + @Override + public boolean equals(Object o) { + return (o instanceof CNodeTpl + && super.equals(o)); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java new file mode 100644 index 0000000..f08769e --- /dev/null +++ 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.Arrays;
+
+import org.apache.sysml.parser.Expression.DataType;
+
+
+public class CNodeUnary extends CNode
+{
+	public enum UnaryType {
+		ROW_SUMS, LOOKUP, LOOKUP0,
+		EXP, POW2, MULT2, SQRT, LOG,
+		ABS, ROUND, CEIL,FLOOR, SIGN,
+		SIN, COS, TAN, ASIN, ACOS, ATAN,
+		IQM, STOP, //NOTE: no code template yet (getTemplate throws for these)
+		DOTPRODUCT_ROW_SUMS; //row sums via dot product for debugging; NOTE: neither getTemplate nor setOutputDims support it yet
+
+		public static boolean contains(String value) {
+			for( UnaryType ut : values() )
+				if( ut.toString().equals(value) )
+					return true;
+			return false;
+		}
+
+		public String getTemplate(boolean sparse) {
+			switch (this) {
+				case ROW_SUMS: //sparse variant passes values/indexes arrays separately
+					return sparse ? "    double %TMP% = LibSpoofPrimitives.vectSum( %IN1v%, %IN1i%, %POS1%, %LEN%);\n":
+									"    double %TMP% = LibSpoofPrimitives.vectSum( %IN1%, %POS1%, %LEN%);\n";
+				case EXP:
+					return "    double %TMP% = FastMath.exp(%IN1%);\n";
+				case LOOKUP:
+					return "    double %TMP% = %IN1%[_rowIndex];\n" ;
+				case LOOKUP0:
+					return "    double %TMP% = %IN1%[0];\n" ;
+				case POW2:
+					return "    double %TMP% = %IN1% * %IN1%;\n" ;
+				case MULT2:
+					return "    double %TMP% = %IN1% + %IN1%;\n" ;
+				case ABS:
+					return "    double %TMP% = Math.abs(%IN1%);\n";
+				case SIN:
+					return "    double %TMP% = Math.sin(%IN1%);\n";
+				case COS:
+					return "    double %TMP% = Math.cos(%IN1%);\n";
+				case TAN:
+					return "    double %TMP% = Math.tan(%IN1%);\n";
+				case ASIN:
+					return "    double %TMP% = Math.asin(%IN1%);\n";
+				case ACOS:
+					return "    double %TMP% = Math.acos(%IN1%);\n";
+				case ATAN:
+					return "    double %TMP% = Math.atan(%IN1%);\n";
+				case SIGN:
+					return "    double %TMP% = Math.signum(%IN1%);\n";
+				case SQRT:
+					return "    double %TMP% = Math.sqrt(%IN1%);\n";
+				case LOG:
+					return "    double %TMP% = FastMath.log(%IN1%);\n";
+				case ROUND:
+					return "    double %TMP% = Math.round(%IN1%);\n";
+				case CEIL:
+					return "    double %TMP% = Math.ceil(%IN1%);\n";
+				case FLOOR:
+					return "    double %TMP% = Math.floor(%IN1%);\n";
+				default:
+					throw new RuntimeException("Invalid unary type: "+this.toString());
+			}
+		}
+	}
+
+	private final UnaryType _type;
+
+	public CNodeUnary( CNode in1, UnaryType type ) {
+		_inputs.add(in1);
+		_type = type;
+		setOutputDims();
+	}
+
+	public UnaryType getType() {
+		return _type;
+	}
+
+	@Override
+	public String codegen(boolean sparse) {
+		if( _generated )
+			return "";
+
+		StringBuilder sb = new StringBuilder();
+
+		//generate children
+		sb.append(_inputs.get(0).codegen(sparse));
+
+		//generate unary operation by filling the type-specific template
+		String var = createVarname();
+		String tmp = _type.getTemplate(sparse);
+		tmp = tmp.replaceAll("%TMP%", var);
+
+		String varj = _inputs.get(0).getVarname();
+		if( sparse && !tmp.contains("%IN1%") ) { //sparse template uses values/indexes placeholders
+			tmp = tmp.replaceAll("%IN1v%", varj+"vals");
+			tmp = tmp.replaceAll("%IN1i%", varj+"ix");
+		}
+		else
+			tmp = tmp.replaceAll("%IN1%", varj );
+
+		if(varj.startsWith("_b") ) //i.e. b.get(index)
+		{
+			tmp = tmp.replaceAll("%POS1%", "_bi");
+			tmp = tmp.replaceAll("%POS2%", "_bi");
+		}
+		tmp = tmp.replaceAll("%POS1%", varj+"i");
+		tmp = tmp.replaceAll("%POS2%", varj+"i");
+
+		sb.append(tmp);
+
+		//mark as generated
+		_generated = true;
+
+		return sb.toString();
+	}
+
+	@Override
+	public String toString() {
+		switch(_type) {
+			case ROW_SUMS: return "u(R+)";
+			default:
+				return super.toString();
+		}
+	}
+
+	@Override
+	public void setOutputDims() {
+		switch(_type)
+		{
+			case ROW_SUMS:
+			case EXP:
+			case LOOKUP:
+			case LOOKUP0:
+			case POW2:
+			case MULT2:
+			case ABS:
+			case SIN:
+			case COS:
+			case TAN:
+			case ASIN:
+			case ACOS:
+			case ATAN:
+			case SIGN:
+			case SQRT:
+			case LOG:
+			case ROUND:
+			case IQM:
+			case STOP:
+			case CEIL:
+			case FLOOR:
+				_rows = 0; //all types listed above produce a scalar output
+				_cols = 0;
+				_dataType= DataType.SCALAR;
+				break;
+			default:
+				throw new RuntimeException("Operation " + _type.toString() + " has no "
+					+ "output dimensions, dimensions needs to be specified for the CNode " );
+		}
+
+	}
+
+	@Override
+	public int hashCode() {
+		if( _hash == 0 ) {
+			int h1 = super.hashCode();
+			int h2 = _type.hashCode();
+			_hash = Arrays.hashCode(new int[]{h1,h2});
+		}
+		return _hash;
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if( !(o instanceof CNodeUnary) )
+			return false;
+
+		CNodeUnary that = (CNodeUnary) o;
+		return super.equals(that)
+			&& _type == that._type;
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java
new file mode 100644
index 0000000..4b7ecbf
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.runtime.matrix.data.Pair;
+
+public abstract class BaseTpl //common state and contract of all fusion templates
+{
+	public enum TemplateType {
+		CellTpl,
+		OuterProductTpl,
+		RowAggTpl
+	}
+
+	private TemplateType _type = null;
+
+	protected ArrayList<Hop> _matrixInputs = new ArrayList<Hop>();
+	protected Hop _initialHop;
+	protected Hop _endHop;
+	protected ArrayList<CNodeData> _initialCnodes = new ArrayList<CNodeData>();
+	protected ArrayList<Hop> _adddedMatrices = new ArrayList<Hop>(); //NOTE(review): name typo ('addded') kept, subclasses may reference it
+	protected boolean _endHopReached = false;
+
+	protected LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> _cpplans = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>(); //keyed by hop ID
+
+	protected BaseTpl(TemplateType type) {
+		_type = type;
+	}
+
+	public TemplateType getType() {
+		return _type;
+	}
+
+	public abstract boolean openTpl(Hop hop);
+
+	public abstract boolean findTplBoundaries(Hop initialHop, CplanRegister cplanRegister);
+
+	public abstract LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals) throws DMLException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
new file mode 100644
index 0000000..0c841e8
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.Map.Entry;
+
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
+import org.apache.sysml.hops.codegen.cplan.CNodeCell;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+
+public class CellTpl extends BaseTpl //fuses chains of cell-wise unary/binary matrix operations
+{
+
+	public CellTpl() {
+		super(TemplateType.CellTpl);
+	}
+
+	@Override
+	public boolean openTpl(Hop hop) {
+		return isValidOperation(hop);
+	}
+
+	@Override
+	public boolean findTplBoundaries(Hop initialHop, CplanRegister cplanRegister) {
+		_initialHop = initialHop;
+		rFindCellwisePattern(initialHop, new HashMap<Long, Hop>());
+
+		//if cplanRegister has the initial hop then no need to reconstruct
+		if(cplanRegister.containsHop(TemplateType.CellTpl, _initialHop.getHopID()))
+			return false;
+
+		//re-assign initialHop to fuse the sum/rowsums (before checking for chains)
+		for (Hop h : _initialHop.getParent())
+			if( h instanceof AggUnaryOp && ((AggUnaryOp) h).getOp() == AggOp.SUM
+				&& ((AggUnaryOp) h).getDirection()!= Direction.Col ) {
+				_initialHop = h;
+			}
+
+		//unary matrix && endHop found && endHop is not direct child of the initialHop (i.e., chain of operators)
+		if(_endHop != null && _endHop != _initialHop)
+		{
+
+			// if final hop is unary add its child to the input
+			if(_endHop instanceof UnaryOp)
+				_matrixInputs.add(_endHop.getInput().get(0));
+			//if one input is scalar then add the other as major input
+			else if(_endHop.getInput().get(0).getDataType() == DataType.SCALAR)
+				_matrixInputs.add(_endHop.getInput().get(1));
+			else if(_endHop.getInput().get(1).getDataType() == DataType.SCALAR)
+				_matrixInputs.add(_endHop.getInput().get(0));
+			//if one is matrix and the other is vector add the matrix
+			else if(TemplateUtils.isMatrix(_endHop.getInput().get(0)) && TemplateUtils.isVector(_endHop.getInput().get(1)) )
+				_matrixInputs.add(_endHop.getInput().get(0));
+			else if(TemplateUtils.isMatrix(_endHop.getInput().get(1)) && TemplateUtils.isVector(_endHop.getInput().get(0)) )
+				_matrixInputs.add(_endHop.getInput().get(1));
+			//both are vectors (add any of them)
+			else
+				_matrixInputs.add(_endHop.getInput().get(0));
+
+			return true;
+		}
+
+		return false;
+	}
+
+	private void rFindCellwisePattern(Hop h, HashMap<Long,Hop> memo)
+	{
+		if(memo.containsKey(h.getHopID()))
+			return;
+
+		//stop recursion if stopping operator
+		if(h.getDataType() == DataType.SCALAR || !isValidOperation(h))
+			return;
+
+		//process children recursively
+		_endHop = h;
+		for( Hop in : h.getInput() )
+		{
+			//propagate the _endHop from bottom to top
+			if(memo.containsKey(in.getHopID()))
+				_endHop=memo.get(in.getHopID());
+			else
+				rFindCellwisePattern(in,memo);
+		}
+
+		memo.put(h.getHopID(), _endHop);
+	}
+
+	@Override
+	public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals)
+		throws DMLException {
+		//re-assign the dimensions of inputs to match the generated code dimensions
+		_initialCnodes.add(new CNodeData(_matrixInputs.get(0), 1, 1, DataType.SCALAR));
+
+		rConstructCellCplan(_initialHop,_initialHop, new HashSet<Long>(), compileLiterals);
+		return _cpplans;
+	}
+
+	public CNode fuseCellWise(Hop initialHop,Hop matrixInput, boolean compileLiterals)
+		throws DMLException {
+		//set the root hop and main matrix input, then return the output of the top-level cplan (or null)
+		_initialHop = initialHop;
+		_matrixInputs.add(matrixInput);
+
+		constructTplCplan(compileLiterals);
+		Entry<Long, Pair<Hop[],CNodeTpl>> toplevel = TemplateUtils.getTopLevelCpplan(_cpplans);
+		if(toplevel != null)
+			return toplevel.getValue().getValue().getOutput();
+		else
+			return null;
+	}
+
+	private void rConstructCellCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals)
+		throws DMLException
+	{
+		if( memo.contains(hop.getHopID()) )
+			return;
+
+
+		//process children recursively
+		for( Hop c : hop.getInput() )
+			rConstructCellCplan(root, c, memo, compileLiterals);
+
+		// first hop to enter here should be _endHop
+		if(TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans))
+		// if direct children are DataGenOps, literals, or already in the cpplans then we are ready to generate code
+		{
+			CNodeCell cellTmpl = null;
+
+			//Fetch operands
+			CNode out = null;
+			ArrayList<CNode> addedCNodes = new ArrayList<CNode>();
+			ArrayList<Hop> addedHops = new ArrayList<Hop>();
+			ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals);
+
+			//if operands are scalar or independent from X
+			boolean independentOperands = hop != root && (hop.getDataType() == DataType.SCALAR || TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[] {_matrixInputs.get(0).getName()}));
+			if(!independentOperands)
+			{
+				if(hop instanceof UnaryOp)
+				{
+					CNode cdata1 = cnodeData.get(0);
+
+					//Primitive Operation has the same name as Hop Type OpOp1
+					String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+					out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
+				}
+				else if(hop instanceof BinaryOp)
+				{
+					BinaryOp bop = (BinaryOp) hop;
+					CNode cdata1 = cnodeData.get(0);
+					CNode cdata2 = cnodeData.get(1);
+
+					//Primitive Operation has the same name as Hop Type OpOp2
+					String primitiveOpName = bop.getOp().toString();
+
+					//cdata1 is vector
+					if( TemplateUtils.isColVector(cdata1) )
+						cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP);
+
+					//cdata2 is vector
+					if( TemplateUtils.isColVector(cdata2) )
+						cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP);
+
+
+					if( bop.getOp()==OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2") )
+						out = new CNodeUnary(cdata1, UnaryType.POW2);
+					else if( bop.getOp()==OpOp2.MULT && cdata2.isLiteral() && cdata2.getVarname().equals("2") )
+						out = new CNodeUnary(cdata1, UnaryType.MULT2);
+					else //default binary
+						out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
+				}
+				else if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM
+					&& (((AggUnaryOp) hop).getDirection() == Direction.RowCol
+					|| ((AggUnaryOp) hop).getDirection() == Direction.Row) && root == hop)
+				{
+					out = cnodeData.get(0);
+				}
+			}
+			// wire output to the template
+			if(out != null || independentOperands)
+			{
+				if(_cpplans.isEmpty())
+				{
+					//first initialization has to have the first variable as input
+					ArrayList<CNode> initialInputs = new ArrayList<CNode>();
+
+					if(independentOperands) // pass the hop itself as an input instead of its children
+					{
+						CNode c = new CNodeData(hop);
+						initialInputs.addAll(_initialCnodes);
+						initialInputs.add(c);
+						cellTmpl = new CNodeCell(initialInputs, c);
+						cellTmpl.setDataType(hop.getDataType());
+						cellTmpl.setCellType(CellType.NO_AGG);
+						cellTmpl.setMultipleConsumers(hop.getParent().size()>1);
+
+						_cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {_matrixInputs.get(0),hop} ,cellTmpl));
+					}
+					else
+					{
+						initialInputs.addAll(_initialCnodes);
+						initialInputs.addAll(cnodeData);
+						cellTmpl = new CNodeCell(initialInputs, out);
+						cellTmpl.setDataType(hop.getDataType());
+						cellTmpl.setCellType(CellType.NO_AGG);
+						cellTmpl.setMultipleConsumers(hop.getParent().size()>1);
+
+						//Hop[] hopArray = new Hop[hop.getInput().size()+1];
+						Hop[] hopArray = new Hop[addedHops.size()+1];
+						hopArray[0] = _matrixInputs.get(0);
+
+						//System.arraycopy( hop.getInput().toArray(), 0, hopArray, 1, hop.getInput().size());
+						System.arraycopy( addedHops.toArray(), 0, hopArray, 1, addedHops.size());
+
+						_cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,cellTmpl));
+					}
+				}
+				else
+				{
+					if(independentOperands)
+					{
+						CNode c = new CNodeData(hop);
+						//clear Operands
+						addedCNodes.clear();
+						addedHops.clear();
+
+						//added the current hop as the input
+						addedCNodes.add(c);
+						addedHops.add(hop);
+						out = c;
+					}
+					//wire the output to existing or new template
+					TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops);
+				}
+			}
+			memo.add(hop.getHopID());
+		}
+	}
+
+	private boolean isValidOperation(Hop hop) {
+		boolean isBinaryMatrixScalar = hop instanceof BinaryOp && hop.getDataType()==DataType.MATRIX &&
+			(hop.getInput().get(0).getDataType()==DataType.SCALAR || hop.getInput().get(1).getDataType()==DataType.SCALAR);
+		boolean isBinaryMatrixVector = hop instanceof BinaryOp && hop.dimsKnown() &&
+			((hop.getInput().get(0).getDataType() == DataType.MATRIX
+				&& TemplateUtils.isVectorOrScalar(hop.getInput().get(1)) && !TemplateUtils.isBinaryMatrixRowVector(hop))
+			||(TemplateUtils.isVectorOrScalar( hop.getInput().get(0))
+				&& hop.getInput().get(1).getDataType() == DataType.MATRIX && !TemplateUtils.isBinaryMatrixRowVector(hop)) );
+		return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop)
+			&& (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector);
+	}
+}
