[SYSTEMML-1914] Extended compression rewrite (auto compression, default) This patch extends the existing compression rewrite (true/false, statically applied to persistent reads of multi-column matrices) with auto compression. If configured with 'auto' (which is enabled by default), we now automatically decide whether or not to compress. Initially, we follow a conservative approach: we decide to compress matrices that are known to exceed aggregate cluster memory, that are not ultra sparse, and only if they are used in loops and all operations are supported over compressed matrices. In order to analyze the individual operations, we make a full pass over the program with awareness of functions and transitive assignments of compressed matrices.
On L2SVM (10 outer iterations, tol=10^-9) and Mnist240m (~540GB in uncompressed serialized form), this change improved performance from 5,913s to 1,952s. At the same time, the overhead of the program analysis is negligible, taking 9ms and 0.8ms on L2SVM for the first and second application of the rewrite. This program analysis is conditional on all the other conditions (incl. large out-of-core matrices) and thus does not affect the compilation of programs over small data. Furthermore, this patch also makes a couple of minor cleanups such as the handling of function keys, variable sets, and hop primitives. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f991e4b4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f991e4b4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f991e4b4 Branch: refs/heads/master Commit: f991e4b48a578f5bc716a564e701a110d44c2cda Parents: 945ca7b Author: Matthias Boehm <mboe...@gmail.com> Authored: Sat Sep 16 15:55:26 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sat Sep 16 17:21:46 2017 -0700 ---------------------------------------------------------------------- conf/SystemML-config.xml.template | 2 +- .../java/org/apache/sysml/conf/DMLConfig.java | 6 +- .../java/org/apache/sysml/hops/FunctionOp.java | 23 +- src/main/java/org/apache/sysml/hops/Hop.java | 34 +-- .../sysml/hops/ipa/FunctionCallGraph.java | 2 +- .../sysml/hops/ipa/InterProceduralAnalysis.java | 14 +- .../sysml/hops/rewrite/ProgramRewriter.java | 6 +- .../hops/rewrite/RewriteCompressedReblock.java | 306 +++++++++++++++++-- .../RewriteSplitDagDataDependentOperators.java | 2 +- .../java/org/apache/sysml/lops/Compression.java | 13 +- .../org/apache/sysml/parser/VariableSet.java | 5 + .../parfor/opt/OptTreeConverter.java | 9 +- .../parfor/opt/OptTreePlanChecker.java | 6 +- 13 files changed, 343 insertions(+), 85 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/conf/SystemML-config.xml.template ---------------------------------------------------------------------- diff --git a/conf/SystemML-config.xml.template b/conf/SystemML-config.xml.template index aaf7316..1731a3b 100644 --- a/conf/SystemML-config.xml.template +++ b/conf/SystemML-config.xml.template @@ -55,7 +55,7 @@ <cp.parallel.io>true</cp.parallel.io> <!-- enables compressed linear algebra, experimental feature --> - <compressed.linalg>false</compressed.linalg> + <compressed.linalg>auto</compressed.linalg> <!-- enables operator fusion via code generation, experimental feature --> <codegen.enabled>false</codegen.enabled> http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/conf/DMLConfig.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java index e5999fe..b1d0a2e 100644 --- a/src/main/java/org/apache/sysml/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java @@ -41,6 +41,7 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.codegen.SpoofCompiler.CompilerType; +import org.apache.sysml.lops.Compression; import org.apache.sysml.parser.ParseException; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.io.IOUtilFunctions; @@ -72,12 +73,13 @@ public class DMLConfig public static final String YARN_APPQUEUE = "dml.yarn.app.queue"; public static final String CP_PARALLEL_OPS = "cp.parallel.ops"; public static final String CP_PARALLEL_IO = "cp.parallel.io"; - public static final String COMPRESSED_LINALG = "compressed.linalg"; + public static final String COMPRESSED_LINALG = "compressed.linalg"; //auto, true, false 
public static final String NATIVE_BLAS = "native.blas"; public static final String CODEGEN = "codegen.enabled"; //boolean public static final String CODEGEN_COMPILER = "codegen.compiler"; //see SpoofCompiler.CompilerType public static final String CODEGEN_PLANCACHE = "codegen.plancache"; //boolean public static final String CODEGEN_LITERALS = "codegen.literals"; //1..heuristic, 2..always + public static final String EXTRA_FINEGRAINED_STATS = "systemml.stats.finegrained"; //boolean public static final String STATS_MAX_WRAP_LEN = "systemml.stats.maxWrapLength"; //int public static final String EXTRA_GPU_STATS = "systemml.stats.extraGPU"; //boolean @@ -120,7 +122,7 @@ public class DMLConfig _defaultVals.put(YARN_APPQUEUE, "default" ); _defaultVals.put(CP_PARALLEL_OPS, "true" ); _defaultVals.put(CP_PARALLEL_IO, "true" ); - _defaultVals.put(COMPRESSED_LINALG, "false" ); + _defaultVals.put(COMPRESSED_LINALG, Compression.CompressConfig.AUTO.name() ); _defaultVals.put(CODEGEN, "false" ); _defaultVals.put(CODEGEN_COMPILER, CompilerType.AUTO.name() ); _defaultVals.put(CODEGEN_PLANCACHE, "true" ); http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/hops/FunctionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java index d3daa15..62872ea 100644 --- a/src/main/java/org/apache/sysml/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java @@ -26,6 +26,7 @@ import org.apache.sysml.lops.FunctionCallCPSingle; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.controlprogram.Program; @@ -84,19 +85,21 @@ public class 
FunctionOp extends Hop /** FunctionOps may have any number of inputs. */ @Override public void checkArity() throws HopsException {} - - public String getFunctionNamespace() - { + + public String getFunctionKey() { + return DMLProgram.constructFunctionKey( + getFunctionNamespace(), getFunctionName()); + } + + public String getFunctionNamespace() { return _fnamespace; } - public String getFunctionName() - { + public String getFunctionName() { return _fname; } - public void setFunctionName( String fname ) - { + public void setFunctionName( String fname ) { _fname = fname; } @@ -104,13 +107,11 @@ public class FunctionOp extends Hop return _outputHops; } - public String[] getOutputVariableNames() - { + public String[] getOutputVariableNames() { return _outputs; } - public FunctionType getFunctionType() - { + public FunctionType getFunctionType() { return _type; } http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 7ccfe2e..5ee0b56 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -231,14 +231,6 @@ public abstract class Hop implements ParseInfo } } } - - public void setRequiresReblock(boolean flag) { - _requiresReblock = flag; - } - - public void setRequiresCompression(boolean flag) { - _requiresCompression = flag; - } public boolean hasMatrixInputWithDifferentBlocksizes() { @@ -254,27 +246,35 @@ public abstract class Hop implements ParseInfo return false; } - public void setOutputBlocksizes( long brlen, long bclen ) - { + public void setOutputBlocksizes( long brlen, long bclen ) { setRowsInBlock( brlen ); setColsInBlock( bclen ); } - public boolean requiresReblock() - { + public void setRequiresReblock(boolean flag) { + _requiresReblock = flag; + } + + public boolean 
requiresReblock() { return _requiresReblock; } - public void setRequiresCheckpoint(boolean flag) - { + public void setRequiresCheckpoint(boolean flag) { _requiresCheckpoint = flag; } - public boolean requiresCheckpoint() - { + public boolean requiresCheckpoint() { return _requiresCheckpoint; } - + + public void setRequiresCompression(boolean flag) { + _requiresCompression = flag; + } + + public boolean requiresCompression() { + return _requiresCompression; + } + public void constructAndSetLopsDataFlowProperties() throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java index 4a630c0..c6c3016 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java @@ -266,7 +266,7 @@ public class FunctionCallGraph for( Hop h : hopsDAG ) { if( h instanceof FunctionOp ){ FunctionOp fop = (FunctionOp) h; - String lfkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); + String lfkey = fop.getFunctionKey(); //keep all function operators if( !_fCalls.containsKey(lfkey) ) _fCalls.put(lfkey, new ArrayList<FunctionOp>()); http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java index b813685..4a44317 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java @@ -473,7 +473,7 @@ 
public class InterProceduralAnalysis { //maintain counters and investigate functions if not seen so far FunctionOp fop = (FunctionOp) hop; - String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); + String fkey = fop.getFunctionKey(); if( fop.getFunctionType() == FunctionType.DML ) { @@ -527,7 +527,7 @@ public class InterProceduralAnalysis { ArrayList<DataIdentifier> inputVars = fstmt.getInputParams(); ArrayList<Hop> inputOps = fop.getInput(); - String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); + String fkey = fop.getFunctionKey(); for( int i=0; i<inputVars.size(); i++ ) { @@ -587,7 +587,7 @@ public class InterProceduralAnalysis { ArrayList<DataIdentifier> foutputOps = fstmt.getOutputParams(); String[] outputVars = fop.getOutputVariableNames(); - String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); + String fkey = fop.getFunctionKey(); try { @@ -650,7 +650,7 @@ public class InterProceduralAnalysis { ArrayList<DataIdentifier> foutputOps = fstmt.getOutputParams(); String[] outputVars = fop.getOutputVariableNames(); - String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); + String fkey = fop.getFunctionKey(); try { @@ -661,7 +661,7 @@ public class InterProceduralAnalysis if( di.getDataType()==DataType.MATRIX ) { - MatrixObject moOut = createOutputMatrix(-1, -1, -1); + MatrixObject moOut = createOutputMatrix(-1, -1, -1); callVars.put(pvarname, moOut); } } @@ -675,14 +675,14 @@ public class InterProceduralAnalysis private void extractFunctionCallEquivalentReturnStatistics( FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) throws HopsException { - String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); try { Hop input = fop.getInput().get(0); MatrixObject moOut = createOutputMatrix(input.getDim1(), input.getDim2(), -1); 
callVars.put(fop.getOutputVariableNames()[0], moOut); } catch( Exception ex ) { - throw new HopsException( "Failed to extract output statistics for unary function "+fkey+".", ex); + throw new HopsException( "Failed to extract output statistics " + + "for unary function "+fop.getFunctionKey()+".", ex); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 82eff52..acc9da9 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -91,8 +91,7 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteTransientWriteParentHandling() ); _dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize _dagRuleSet.add( new RewriteBlockSizeAndReblock() ); - _dagRuleSet.add( new RewriteCompressedReblock() ); - _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); + _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) @@ -110,6 +109,7 @@ public class ProgramRewriter _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding _sbRuleSet.add( new RewriteMergeBlockSequence() ); //dependency: remove branches } + _sbRuleSet.add( new RewriteCompressedReblock() ); if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS ) _sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock, merge blocks if( ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS) ) @@ -118,7 +118,7 @@ public class ProgramRewriter _sbRuleSet.add( new 
RewriteForLoopVectorization() ); //dependency: reblock (reblockop) _sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes) if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE ) - _sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() ); + _sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() ); } // DYNAMIC REWRITES (which do require size information) http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java index c150a2b..7e4567d 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java @@ -20,63 +20,309 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; -import org.apache.sysml.hops.DataOp; +import org.apache.sysml.hops.AggBinaryOp; +import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataOpTypes; +import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.lops.Compression.CompressConfig; +import org.apache.sysml.lops.MMTSJ.MMTSJType; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.ForStatement; +import org.apache.sysml.parser.ForStatementBlock; +import org.apache.sysml.parser.FunctionStatement; +import org.apache.sysml.parser.FunctionStatementBlock; +import 
org.apache.sysml.parser.IfStatement; +import org.apache.sysml.parser.IfStatementBlock; +import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.parser.WhileStatement; +import org.apache.sysml.parser.WhileStatementBlock; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; /** * Rule: CompressedReblock: If config compressed.linalg is enabled, we - * inject compression hooks after pread of matrices w/ both dims > 1. + * inject compression directions after pread of matrices w/ both dims > 1 + * (i.e., multi-column matrices). In case of 'auto' compression, we apply + * compression if the datasize is known to exceed aggregate cluster memory, + * the matrix is used in loops, and all operations are supported over + * compressed matrices. */ -public class RewriteCompressedReblock extends HopRewriteRule +public class RewriteCompressedReblock extends StatementBlockRewriteRule { + private static final String TMP_PREFIX = "__cmtx"; @Override - public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) - throws HopsException + public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus sate) + throws HopsException { - if( roots == null ) - return null; + //check for inapplicable statement blocks + if( !HopRewriteUtils.isLastLevelStatementBlock(sb) + || sb.get_hops() == null ) + return Arrays.asList(sb); - boolean compress = ConfigurationManager.getDMLConfig() - .getBooleanValue(DMLConfig.COMPRESSED_LINALG); + //parse compression config + DMLConfig conf = ConfigurationManager.getDMLConfig(); + CompressConfig compress = CompressConfig.valueOf( + conf.getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase()); //perform compressed reblock rewrite - if( compress ) - for( Hop h : roots ) - rule_CompressedReblock(h); - - return roots; + if( compress.isEnabled() ) { + Hop.resetVisitStatus(sb.get_hops()); + for( Hop h : 
sb.get_hops() ) + injectCompressionDirective(h, compress, sb.getDMLProg()); + Hop.resetVisitStatus(sb.get_hops()); + } + return Arrays.asList(sb); } @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) - throws HopsException + public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) + throws HopsException { - //do nothing (ppred will never occur in predicate) - return root; + return sbs; } - - private void rule_CompressedReblock(Hop hop) + + private static void injectCompressionDirective(Hop hop, CompressConfig compress, DMLProgram prog) throws HopsException { - // Go to the source(s) of the DAG - for (Hop hi : hop.getInput()) { - if (!hi.isVisited()) - rule_CompressedReblock(hi); - } - - if( hop instanceof DataOp - && ((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTREAD - && hop.getDim1() > 1 && hop.getDim2() > 1 ) + if( hop.isVisited() || hop.requiresCompression() ) + return; + + // recursively process children + for( Hop hi : hop.getInput() ) + injectCompressionDirective(hi, compress, prog); + // check for compression conditions + if( compress == CompressConfig.TRUE && satisfiesCompressionCondition(hop) + || compress == CompressConfig.AUTO && satisfiesAutoCompressionCondition(hop, prog) ) { hop.setRequiresCompression(true); } - + hop.setVisited(); } + + private static boolean satisfiesCompressionCondition(Hop hop) { + return HopRewriteUtils.isData(hop, DataOpTypes.PERSISTENTREAD) + && hop.getDim1() > 1 && hop.getDim2() > 1; //multi-column matrix + } + + private static boolean satisfiesAutoCompressionCondition(Hop hop, DMLProgram prog) + throws HopsException + { + //check for basic compression condition + if( !(satisfiesCompressionCondition(hop) + && OptimizerUtils.isSparkExecutionMode()) ) + return false; + + //determine if data size exceeds aggregate cluster storage memory + double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity( + hop.getDim1(), hop.getDim2(), 
hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz()); + double cacheSize = SparkExecutionContext.getDataMemoryBudget(true, true); + boolean outOfCore = matrixPSize > cacheSize; + + //determine if matrix is ultra sparse (and hence serialized) + double sparsity = OptimizerUtils.getSparsity(hop.getDim1(), hop.getDim2(), hop.getNnz()); + boolean ultraSparse = sparsity < MatrixBlock.ULTRA_SPARSITY_TURN_POINT; + + //determine if all operations are supported over compressed matrices, + //but conditionally only if all other conditions are met + if( hop.dimsKnown(true) && outOfCore && !ultraSparse ) { + //analyze program recursively, including called functions + ProbeStatus status = new ProbeStatus(hop.getHopID(), prog); + for( StatementBlock sb : prog.getStatementBlocks() ) + rAnalyzeProgram(sb, status); + + //applicable if used in loop (amortized compressed costs), + // no conditional updates in if-else branches + // and all operations are applicable (no decompression costs) + boolean ret = status.foundStart && status.usedInLoop + && !status.condUpdate && !status.nonApplicable; + if( LOG.isDebugEnabled() ) { + LOG.debug("Auto compression: "+ret+" (dimsKnown="+hop.dimsKnown(true) + + ", outOfCore="+outOfCore+", !ultraSparse="+!ultraSparse + +", foundStart="+status.foundStart+", usedInLoop="+status.foundStart + +", !condUpdate="+!status.condUpdate+", !nonApplicable="+!status.nonApplicable+")"); + } + return ret; + } + else if( LOG.isDebugEnabled() ) { + LOG.debug("Auto compression: false (dimsKnown="+hop.dimsKnown(true) + + ", outOfCore="+outOfCore+", !ultraSparse="+!ultraSparse+")"); + } + return false; + } + + private static void rAnalyzeProgram(StatementBlock sb, ProbeStatus status) + throws HopsException + { + if(sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock) sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock csb : fstmt.getBody()) + rAnalyzeProgram(csb, status); + } + else 
if(sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + for (StatementBlock csb : wstmt.getBody()) + rAnalyzeProgram(csb, status); + if( wsb.variablesRead().containsAnyName(status.compMtx) ) + status.usedInLoop = true; + } + else if(sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + for (StatementBlock csb : istmt.getIfBody()) + rAnalyzeProgram(csb, status); + for (StatementBlock csb : istmt.getElseBody()) + rAnalyzeProgram(csb, status); + if( isb.variablesUpdated().containsAnyName(status.compMtx) ) + status.condUpdate = true; + } + else if(sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + for (StatementBlock csb : fstmt.getBody()) + rAnalyzeProgram(csb, status); + if( fsb.variablesRead().containsAnyName(status.compMtx) ) + status.usedInLoop = true; + } + else if( sb.get_hops() != null ) { //generic (last-level) + ArrayList<Hop> roots = sb.get_hops(); + Hop.resetVisitStatus(roots); + //process entire HOP DAG starting from the roots + for( Hop root : roots ) + rAnalyzeHopDag(root, status); + //remove temporary variables + status.compMtx.removeIf(n -> n.startsWith(TMP_PREFIX)); + Hop.resetVisitStatus(roots); + } + } + + private static void rAnalyzeHopDag(Hop current, ProbeStatus status) + throws HopsException + { + if( current.isVisited() ) + return; + + //process children recursively + for( Hop input : current.getInput() ) + rAnalyzeHopDag(input, status); + + //handle source persistent read + if( current.getHopID() == status.startHopID ) { + status.compMtx.add(getTmpName(current)); + status.foundStart = true; + } + + //handle individual hops + //a) handle function calls + if( current instanceof FunctionOp + && hasCompressedInput(current, status) ) + { + //TODO handle of 
functions in a more fine-grained manner + //to cover special cases multiple calls where compressed + //inputs might occur for different input parameters + + FunctionOp fop = (FunctionOp) current; + String fkey = fop.getFunctionKey(); + if( !status.procFn.contains(fkey) ) { + //memoization to avoid redundant analysis and recursive calls + status.procFn.add(fkey); + //map inputs to function inputs + FunctionStatementBlock fsb = status.prog.getFunctionStatementBlock(fkey); + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + ProbeStatus status2 = new ProbeStatus(status); + for(int i=0; i<fop.getInput().size(); i++) + if( status.compMtx.contains(getTmpName(fop.getInput().get(i))) ) + status2.compMtx.add(fstmt.getInputParams().get(i).getName()); + //analyze function and merge meta info + rAnalyzeProgram(fsb, status2); + status.foundStart |= status2.foundStart; + status.usedInLoop |= status2.usedInLoop; + status.condUpdate |= status2.condUpdate; + status.nonApplicable |= status2.nonApplicable; + //map function outputs to outputs + String[] outputs = fop.getOutputVariableNames(); + for( int i=0; i<outputs.length; i++ ) + if( status2.compMtx.contains(fstmt.getOutputParams().get(i).getName()) ) + status.compMtx.add(outputs[i]); + } + } + //b) handle transient reads and writes (name mapping) + else if( HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTWRITE) + && status.compMtx.contains(getTmpName(current.getInput().get(0)))) + status.compMtx.add(current.getName()); + else if( HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTREAD) + && status.compMtx.contains(current.getName()) ) + status.compMtx.add(getTmpName(current)); + //c) handle applicable operations + else if( hasCompressedInput(current, status) ) { + boolean compUCOut = //valid with uncompressed outputs + (current instanceof AggBinaryOp && current.getDim2()<= current.getColsInBlock() //tsmm + && ((AggBinaryOp)current).checkTransposeSelf()==MMTSJType.LEFT) + || (current instanceof AggBinaryOp 
&& (current.getDim1()==1 || current.getDim2()==1)) //mvmm + || (HopRewriteUtils.isTransposeOperation(current) && current.getParent().size()==1 + && current.getParent().get(0) instanceof AggBinaryOp + && (current.getParent().get(0).getDim1()==1 || current.getParent().get(0).getDim2()==1)) + || HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX); + boolean compCOut = //valid with compressed outputs + HopRewriteUtils.isBinaryMatrixScalarOperation(current) + || HopRewriteUtils.isBinary(current, OpOp2.CBIND); + boolean metaOp = HopRewriteUtils.isUnary(current, OpOp1.NROW, OpOp1.NCOL); + status.nonApplicable |= !(compUCOut || compCOut || metaOp); + if( compCOut ) + status.compMtx.add(getTmpName(current)); + } + + current.setVisited(); + } + + private static String getTmpName(Hop hop) { + return TMP_PREFIX + hop.getHopID(); + } + + private static boolean hasCompressedInput(Hop hop, ProbeStatus status) { + if( status.compMtx.isEmpty() ) + return false; + for( Hop input : hop.getInput() ) + if( status.compMtx.contains(getTmpName(input)) ) + return true; + return false; + } + + private static class ProbeStatus { + private final long startHopID; + private final DMLProgram prog; + private boolean foundStart = false; + private boolean usedInLoop = false; + private boolean condUpdate = false; + private boolean nonApplicable = false; + private HashSet<String> procFn = new HashSet<>(); + private HashSet<String> compMtx = new HashSet<>(); + public ProbeStatus(long hopID, DMLProgram p) { + startHopID = hopID; + prog = p; + } + public ProbeStatus(ProbeStatus status) { + startHopID = status.startHopID; + prog = status.prog; + foundStart = status.foundStart; + usedInLoop = status.usedInLoop; + condUpdate = status.condUpdate; + nonApplicable = status.nonApplicable; + procFn.addAll(status.procFn); + } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index 1f49500..f8d0a1b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -369,7 +369,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //create additional cut by rewriting both hop dags int pos = HopRewriteUtils.getChildReferencePos(hop, c); HopRewriteUtils.removeChildReferenceByPos(hop, c, pos); - HopRewriteUtils.addChildReference(hop, tread, pos); + HopRewriteUtils.addChildReference(hop, tread, pos); //update live in and out of new statement block (for piggybacking) DataIdentifier diVar = new DataIdentifier(varname); http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/lops/Compression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Compression.java b/src/main/java/org/apache/sysml/lops/Compression.java index 75feef4..65293ef 100644 --- a/src/main/java/org/apache/sysml/lops/Compression.java +++ b/src/main/java/org/apache/sysml/lops/Compression.java @@ -27,8 +27,17 @@ import org.apache.sysml.parser.Expression.ValueType; public class Compression extends Lop { - public static final String OPCODE = "compress"; - + public static final String OPCODE = "compress"; + + public enum CompressConfig { + TRUE, + FALSE, + AUTO; + public boolean isEnabled() { + return this == TRUE || this == AUTO; + } + } + public Compression(Lop input, DataType dt, ValueType vt, ExecType et) throws LopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/parser/VariableSet.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/VariableSet.java b/src/main/java/org/apache/sysml/parser/VariableSet.java index 16ad08b..c2a1151 100644 --- a/src/main/java/org/apache/sysml/parser/VariableSet.java +++ b/src/main/java/org/apache/sysml/parser/VariableSet.java @@ -56,6 +56,11 @@ public class VariableSet return _variables.containsKey(name); } + public boolean containsAnyName(Set<String> names){ + return _variables.keySet().stream() + .anyMatch(n -> names.contains(n)); + } + public DataIdentifier getVariable(String name){ return _variables.get(name); } http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java index 1aae331..0bdd59f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java @@ -540,7 +540,7 @@ public class OptTreeConverter FunctionOp fhop = (FunctionOp) hop; String fname = fhop.getFunctionName(); String fnspace = fhop.getFunctionNamespace(); - String fKey = DMLProgram.constructFunctionKey(fnspace, fname); + String fKey = fhop.getFunctionKey(); Object[] prog = _hlMap.getRootProgram(); OptNode node = new OptNode(NodeType.FUNCCALL); @@ -558,16 +558,13 @@ public class OptTreeConverter if( !memo.contains(fKey) ) { memo.add(fKey); - int len = fs.getBody().size(); - for( int i=0; i<fpb.getChildBlocks().size() && i<len; i++ ) - { + for( int i=0; i<fpb.getChildBlocks().size() && i<len; i++ ) { ProgramBlock lpb = fpb.getChildBlocks().get(i); StatementBlock lsb = fs.getBody().get(i); node.addChild( 
rCreateAbstractOptNode(lsb, lpb, vars, false, memo) ); } - - memo.remove(fKey); + memo.remove(fKey); } else node.addParam(ParamType.RECURSIVE_CALL, "true"); http://git-wip-us.apache.org/repos/asf/systemml/blob/f991e4b4/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreePlanChecker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreePlanChecker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreePlanChecker.java index db0f297..2342b76 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreePlanChecker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreePlanChecker.java @@ -189,11 +189,9 @@ public class OptTreePlanChecker return; //process functionop - if( hop instanceof FunctionOp ) - { + if( hop instanceof FunctionOp ) { FunctionOp fop = (FunctionOp) hop; - String key = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(),fop.getFunctionName()); - memo.put(key, fop); + memo.put(fop.getFunctionKey(), fop); } //process children