Repository: systemml Updated Branches: refs/heads/master 2c9694dec -> ac1cf093a
[SYSTEMML-1808] [SYSTEMML-1658] Visualize Hop DAG for explaining the optimizer - Also added a utility to print java output in notebook. - Fixed a bug in dmlFromResource. Closes #596. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ac1cf093 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ac1cf093 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ac1cf093 Branch: refs/heads/master Commit: ac1cf093ad0b47cb6a0f0d48c4deb276b4ae1fa6 Parents: 2c9694d Author: Niketan Pansare <[email protected]> Authored: Thu Aug 3 09:06:24 2017 -0800 Committer: Niketan Pansare <[email protected]> Committed: Thu Aug 3 10:06:24 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/api/mlcontext/MLContext.java | 25 +- .../sysml/api/mlcontext/MLContextUtil.java | 108 +++++++ .../sysml/api/mlcontext/ScriptExecutor.java | 46 ++- .../context/SparkExecutionContext.java | 2 +- .../sysml/runtime/instructions/Instruction.java | 20 ++ .../java/org/apache/sysml/utils/Explain.java | 301 +++++++++++++++++++ src/main/python/systemml/mlcontext.py | 100 +++++- .../scala/org/apache/sysml/api/ml/Utils.scala | 61 ++++ 8 files changed, 646 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java index b35faa6..f74d593 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -46,7 +46,6 @@ import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; import org.apache.sysml.runtime.matrix.data.OutputInfo; import 
org.apache.sysml.utils.Explain.ExplainType; import org.apache.sysml.utils.MLContextProxy; - /** * The MLContext API offers programmatic access to SystemML on Spark from * languages such as Scala, Java, and Python. @@ -287,6 +286,8 @@ public class MLContext { public void resetConfig() { MLContextUtil.setDefaultConfig(); } + + /** * Set configuration property, such as @@ -305,7 +306,8 @@ public class MLContext { throw new MLContextException(e); } } - + + /** * Execute a DML or PYDML Script. * @@ -357,6 +359,16 @@ public class MLContext { throw new MLContextException("Exception when executing script", e); } } + + /** + * Sets the script that is being executed + * + * @param executionScript + * script that is being executed + */ + public void setExecutionScript(Script executionScript) { + this.executionScript = executionScript; + } /** * Set SystemML configuration based on a configuration file. @@ -489,6 +501,15 @@ public class MLContext { } /** + * Whether or not the "force" GPU mode is enabled. + * + * @return true if enabled, false otherwise + */ + public boolean isForceGPU() { + return this.forceGPU; + } + + /** * Used internally by MLContextProxy. 
* */ http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java index 2c9566c..51d38a5 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java @@ -23,6 +23,7 @@ import java.io.File; import java.io.FileNotFoundException; import java.net.URL; import java.util.ArrayList; +import java.util.Date; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; @@ -35,6 +36,7 @@ import javax.xml.parsers.DocumentBuilderFactory; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.text.WordUtils; +import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -52,8 +54,11 @@ import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.CompilerConfig.ConfigType; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.parser.LanguageException; import org.apache.sysml.parser.ParseException; import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.ForProgramBlock; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.IfProgramBlock; @@ -63,6 +68,8 @@ import org.apache.sysml.runtime.controlprogram.ProgramBlock; import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import 
org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.instructions.cp.Data; @@ -73,6 +80,7 @@ import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.utils.Explain; import org.apache.sysml.utils.MLContextProxy; import org.w3c.dom.Document; import org.w3c.dom.Node; @@ -83,6 +91,106 @@ import org.w3c.dom.NodeList; * */ public final class MLContextUtil { + + /** + * Get HOP DAG in dot format for a DML or PYDML Script. + * + * @param mlCtx + * MLContext object. + * @param script + * The DML or PYDML Script object to execute. + * @param lines + * Only display the hops that have begin and end line number + * equals to the given integers. + * @param performHOPRewrites + * should perform static rewrites, perform + * intra-/inter-procedural analysis to propagate size information + * into functions and apply dynamic rewrites + * @param withSubgraph + * If false, the dot graph will be created without subgraphs for + * statement blocks. 
+ * @return hop DAG in dot format + * @throws LanguageException + * if error occurs + * @throws DMLRuntimeException + * if error occurs + * @throws HopsException + * if error occurs + */ + public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, + boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException, + LanguageException { + return getHopDAG(mlCtx, script, lines, null, performHOPRewrites, withSubgraph); + } + + /** + * Get HOP DAG in dot format for a DML or PYDML Script. + * + * @param mlCtx + * MLContext object. + * @param script + * The DML or PYDML Script object to execute. + * @param lines + * Only display the hops that have begin and end line number + * equals to the given integers. + * @param newConf + * Spark Configuration. + * @param performHOPRewrites + * should perform static rewrites, perform + * intra-/inter-procedural analysis to propagate size information + * into functions and apply dynamic rewrites + * @param withSubgraph + * If false, the dot graph will be created without subgraphs for + * statement blocks. 
+ * @return hop DAG in dot format + * @throws LanguageException + * if error occurs + * @throws DMLRuntimeException + * if error occurs + * @throws HopsException + * if error occurs + */ + public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, SparkConf newConf, + boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException, + LanguageException { + SparkConf oldConf = mlCtx.getSparkSession().sparkContext().getConf(); + SparkExecutionContext.SparkClusterConfig systemmlConf = SparkExecutionContext.getSparkClusterConfig(); + long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory(); + try { + if (newConf != null) { + systemmlConf.analyzeSparkConfiguation(newConf); + InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory")); + } + ScriptExecutor scriptExecutor = new ScriptExecutor(); + scriptExecutor.setExecutionType(mlCtx.getExecutionType()); + scriptExecutor.setGPU(mlCtx.isGPU()); + scriptExecutor.setForceGPU(mlCtx.isForceGPU()); + scriptExecutor.setInit(mlCtx.isInitBeforeExecution()); + if (mlCtx.isInitBeforeExecution()) { + mlCtx.setInitBeforeExecution(false); + } + scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable()); + + Long time = new Long((new Date()).getTime()); + if ((script.getName() == null) || (script.getName().equals(""))) { + script.setName(time.toString()); + } + + mlCtx.setExecutionScript(script); + scriptExecutor.compile(script, performHOPRewrites); + Explain.reset(); + // To deal with potential Py4J issues + lines = lines.size() == 1 && lines.get(0) == -1 ? 
new ArrayList<Integer>() : lines; + return Explain.getHopDAG(scriptExecutor.dmlProgram, lines, withSubgraph); + } catch (RuntimeException e) { + throw new MLContextException("Exception when compiling script", e); + } finally { + if (newConf != null) { + systemmlConf.analyzeSparkConfiguation(oldConf); + InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory); + } + } + } /** * Basic data types supported by the MLContext API http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index 467e94e..7e78891 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -263,9 +263,16 @@ public class ScriptExecutor { DMLScript.USE_ACCELERATOR = oldGPU; DMLScript.STATISTICS_COUNT = DMLOptions.defaultOptions.statsCount; } - + + public void compile(Script script) { + compile(script, true); + } + /** - * Execute a DML or PYDML script. This is broken down into the following + * Compile a DML or PYDML script. This will help analysis of DML programs + * that have dynamic recompilation flag set to false without actually executing it. 
+ * + * This is broken down into the following * primary methods: * * <ol> @@ -283,16 +290,14 @@ public class ScriptExecutor { * <li>{@link #countCompiledMRJobsAndSparkInstructions()}</li> * <li>{@link #initializeCachingAndScratchSpace()}</li> * <li>{@link #cleanupRuntimeProgram()}</li> - * <li>{@link #createAndInitializeExecutionContext()}</li> - * <li>{@link #executeRuntimeProgram()}</li> - * <li>{@link #cleanupAfterExecution()}</li> * </ol> * * @param script - * the DML or PYDML script to execute - * @return the results as a MLResults object + * the DML or PYDML script to compile + * @param performHOPRewrites + * should perform static rewrites, perform intra-/inter-procedural analysis to propagate size information into functions and apply dynamic rewrites */ - public MLResults execute(Script script) { + public void compile(Script script, boolean performHOPRewrites) { // main steps in script execution setup(script); @@ -303,7 +308,8 @@ public class ScriptExecutor { liveVariableAnalysis(); validateScript(); constructHops(); - rewriteHops(); + if(performHOPRewrites) + rewriteHops(); rewritePersistentReadsAndWrites(); constructLops(); generateRuntimeProgram(); @@ -315,6 +321,28 @@ public class ScriptExecutor { if (statistics) { Statistics.stopCompileTimer(); } + } + + + /** + * Execute a DML or PYDML script. 
This is broken down into the following + * primary methods: + * + * <ol> + * <li>{@link #compile(Script)}</li> + * <li>{@link #createAndInitializeExecutionContext()}</li> + * <li>{@link #executeRuntimeProgram()}</li> + * <li>{@link #cleanupAfterExecution()}</li> + * </ol> + * + * @param script + * the DML or PYDML script to execute + * @return the results as a MLResults object + */ + public MLResults execute(Script script) { + + // main steps in script execution + compile(script); try { createAndInitializeExecutionContext(); http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index 6f2f766..d1ff7d8 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -1378,7 +1378,7 @@ public class SparkExecutionContext extends ExecutionContext * degree of parallelism. This configuration abstracts legacy (< Spark 1.6) and current * configurations and provides a unified view. 
*/ - private static class SparkClusterConfig + public static class SparkClusterConfig { //broadcasts are stored in mem-and-disk in data space, this config //defines the fraction of data space to be used as broadcast budget http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java index 6db8c7f..374f81c 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java @@ -63,6 +63,26 @@ public abstract class Instruction protected int beginCol = -1; protected int endCol = -1; + public String getFilename() { + return filename; + } + + public int getBeginLine() { + return beginLine; + } + + public int getEndLine() { + return endLine; + } + + public int getBeginColumn() { + return beginCol; + } + + public int getEndColumn() { + return endCol; + } + public void setType (INSTRUCTION_TYPE tp ) { type = tp; } http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/utils/Explain.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java index a2e843a..01b59f7 100644 --- a/src/main/java/org/apache/sysml/utils/Explain.java +++ b/src/main/java/org/apache/sysml/utils/Explain.java @@ -26,10 +26,16 @@ import java.util.HashSet; import java.util.Map; import java.util.Map.Entry; +import org.apache.sysml.hops.AggBinaryOp; +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.HopsException; import 
org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; @@ -266,6 +272,50 @@ public class Explain return sb.toString(); } + + public static String getHopDAG(DMLProgram prog, ArrayList<Integer> lines, boolean withSubgraph) + throws HopsException, DMLRuntimeException, LanguageException { + StringBuilder sb = new StringBuilder(); + StringBuilder nodes = new StringBuilder(); + + // create header + sb.append("digraph {"); + + // Explain functions (if exists) + if (prog.hasFunctionStatementBlocks()) { + + // show function call graph + // FunctionCallGraph fgraph = new FunctionCallGraph(prog); + // sb.append(explainFunctionCallGraph(fgraph, new HashSet<String>(), + // null, 3)); + + // show individual functions + for (String namespace : prog.getNamespaces().keySet()) { + for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()) { + FunctionStatementBlock fsb = prog.getFunctionStatementBlock(namespace, fname); + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + String fkey = DMLProgram.constructFunctionKey(namespace, fname); + + if (!(fstmt instanceof ExternalFunctionStatement)) { + addSubGraphHeader(sb, withSubgraph); + for (StatementBlock current : fstmt.getBody()) + sb.append(getHopDAG(current, nodes, lines, withSubgraph)); + String label = "FUNCTION " + fkey + " recompile=" + fsb.isRecompileOnce() + "\n"; + addSubGraphFooter(sb, withSubgraph, label); + } + } + } + } + + // Explain main program + for (StatementBlock sblk : prog.getStatementBlocks()) + sb.append(getHopDAG(sblk, nodes, lines, withSubgraph)); + + sb.append(nodes); + sb.append("rankdir = \"BT\"\n"); + sb.append("}\n"); + return sb.toString(); + } public static String explain( Program rtprog ) throws HopsException { 
return explain(rtprog, null); @@ -466,6 +516,128 @@ public class Explain ////////////// // internal explain HOPS + private static int clusterID = 0; + + public static void reset() { + clusterID = 0; + } + + private static void addSubGraphHeader(StringBuilder builder, boolean withSubgraph) { + if (withSubgraph) { + builder.append("subgraph cluster_" + (clusterID++) + " {\n"); + } + } + + private static void addSubGraphFooter(StringBuilder builder, boolean withSubgraph, String label) { + if (withSubgraph) { + builder.append("label = \"" + label + "\";\n"); + builder.append("}\n"); + } + } + + private static StringBuilder getHopDAG(StatementBlock sb, StringBuilder nodes, ArrayList<Integer> lines, + boolean withSubgraph) throws HopsException, DMLRuntimeException { + StringBuilder builder = new StringBuilder(); + + if (sb instanceof WhileStatementBlock) { + addSubGraphHeader(builder, withSubgraph); + + WhileStatementBlock wsb = (WhileStatementBlock) sb; + String label = null; + if (!wsb.getUpdateInPlaceVars().isEmpty()) + label = "WHILE (lines " + wsb.getBeginLine() + "-" + wsb.getEndLine() + ") in-place=" + + wsb.getUpdateInPlaceVars().toString() + ""; + else + label = "WHILE (lines " + wsb.getBeginLine() + "-" + wsb.getEndLine() + ")"; + // TODO: Don't show predicate hops for now + // builder.append(explainHop(wsb.getPredicateHops())); + + WhileStatement ws = (WhileStatement) sb.getStatement(0); + for (StatementBlock current : ws.getBody()) + builder.append(getHopDAG(current, nodes, lines, withSubgraph)); + + addSubGraphFooter(builder, withSubgraph, label); + } else if (sb instanceof IfStatementBlock) { + addSubGraphHeader(builder, withSubgraph); + IfStatementBlock ifsb = (IfStatementBlock) sb; + String label = "IF (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")"; + // TODO: Don't show predicate hops for now + // builder.append(explainHop(ifsb.getPredicateHops(), level+1)); + + IfStatement ifs = (IfStatement) sb.getStatement(0); + for (StatementBlock 
current : ifs.getIfBody()) { + builder.append(getHopDAG(current, nodes, lines, withSubgraph)); + addSubGraphFooter(builder, withSubgraph, label); + } + if (!ifs.getElseBody().isEmpty()) { + addSubGraphHeader(builder, withSubgraph); + label = "ELSE (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")"; + + for (StatementBlock current : ifs.getElseBody()) + builder.append(getHopDAG(current, nodes, lines, withSubgraph)); + addSubGraphFooter(builder, withSubgraph, label); + } + } else if (sb instanceof ForStatementBlock) { + ForStatementBlock fsb = (ForStatementBlock) sb; + addSubGraphHeader(builder, withSubgraph); + String label = ""; + if (sb instanceof ParForStatementBlock) { + if (!fsb.getUpdateInPlaceVars().isEmpty()) + label = "PARFOR (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ") in-place=" + + fsb.getUpdateInPlaceVars().toString() + ""; + else + label = "PARFOR (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ")"; + } else { + if (!fsb.getUpdateInPlaceVars().isEmpty()) + label = "FOR (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ") in-place=" + + fsb.getUpdateInPlaceVars().toString() + ""; + else + label = "FOR (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ")"; + } + // TODO: Don't show predicate hops for now + // if (fsb.getFromHops() != null) + // builder.append(explainHop(fsb.getFromHops(), level+1)); + // if (fsb.getToHops() != null) + // builder.append(explainHop(fsb.getToHops(), level+1)); + // if (fsb.getIncrementHops() != null) + // builder.append(explainHop(fsb.getIncrementHops(), level+1)); + + ForStatement fs = (ForStatement) sb.getStatement(0); + for (StatementBlock current : fs.getBody()) + builder.append(getHopDAG(current, nodes, lines, withSubgraph)); + addSubGraphFooter(builder, withSubgraph, label); + + } else if (sb instanceof FunctionStatementBlock) { + FunctionStatement fsb = (FunctionStatement) sb.getStatement(0); + addSubGraphHeader(builder, withSubgraph); + String label = "Function 
(lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ")"; + for (StatementBlock current : fsb.getBody()) + builder.append(getHopDAG(current, nodes, lines, withSubgraph)); + addSubGraphFooter(builder, withSubgraph, label); + } else { + // For generic StatementBlock + if (sb.requiresRecompilation()) { + addSubGraphHeader(builder, withSubgraph); + } + ArrayList<Hop> hopsDAG = sb.get_hops(); + if (hopsDAG != null && !hopsDAG.isEmpty()) { + Hop.resetVisitStatus(hopsDAG); + for (Hop hop : hopsDAG) + builder.append(getHopDAG(hop, nodes, lines, withSubgraph)); + Hop.resetVisitStatus(hopsDAG); + } + + if (sb.requiresRecompilation()) { + builder.append("style=filled;\n"); + builder.append("color=lightgrey;\n"); + String label = "(lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ") [recompile=" + + sb.requiresRecompilation() + "]"; + addSubGraphFooter(builder, withSubgraph, label); + } + } + return builder; + } + private static String explainStatementBlock(StatementBlock sb, int level) throws HopsException, DMLRuntimeException { @@ -636,6 +808,134 @@ public class Explain return sb.toString(); } + + private static boolean isInRange(Hop hop, ArrayList<Integer> lines) { + boolean isInRange = lines.size() == 0 ? 
true : false; + for (int lineNum : lines) { + if (hop.getBeginLine() == lineNum && lineNum == hop.getEndLine()) { + return true; + } + } + return isInRange; + } + + private static StringBuilder getHopDAG(Hop hop, StringBuilder nodes, ArrayList<Integer> lines, boolean withSubgraph) + throws DMLRuntimeException { + StringBuilder sb = new StringBuilder(); + if (hop.isVisited() || (!SHOW_LITERAL_HOPS && hop instanceof LiteralOp)) + return sb; + + for (Hop input : hop.getInput()) { + if ((SHOW_LITERAL_HOPS || !(input instanceof LiteralOp)) && isInRange(hop, lines)) { + String edgeLabel = showMem(input.getOutputMemEstimate(), true); + sb.append("h" + input.getHopID() + " -> h" + hop.getHopID() + " [label=\"" + edgeLabel + "\"];\n"); + } + } + for (Hop input : hop.getInput()) + sb.append(getHopDAG(input, nodes, lines, withSubgraph)); + + if (isInRange(hop, lines)) { + nodes.append("h" + hop.getHopID() + "[label=\"" + getNodeLabel(hop) + "\", " + "shape=\"" + + getNodeShape(hop) + "\", color=\"" + getNodeColor(hop) + "\", tooltip=\"" + getNodeToolTip(hop) + + "\"];\n"); + } + hop.setVisited(); + + return sb; + } + + private static String getNodeLabel(Hop hop) { + StringBuilder sb = new StringBuilder(); + sb.append(hop.getOpString()); + if (hop instanceof AggBinaryOp) { + AggBinaryOp aggBinOp = (AggBinaryOp) hop; + if (aggBinOp.getMMultMethod() != null) + sb.append(" " + aggBinOp.getMMultMethod().name() + " "); + } + // data flow properties + if (SHOW_DATA_FLOW_PROPERTIES) { + if (hop.requiresReblock() && hop.requiresCheckpoint()) + sb.append(", rblk,chkpt"); + else if (hop.requiresReblock()) + sb.append(", rblk"); + else if (hop.requiresCheckpoint()) + sb.append(", chkpt"); + } + if (hop.getFilename() == null) { + sb.append("[" + hop.getBeginLine() + ":" + hop.getBeginColumn() + "-" + hop.getEndLine() + ":" + + hop.getEndColumn() + "]"); + } else { + sb.append("[" + hop.getFilename() + " " + hop.getBeginLine() + ":" + hop.getBeginColumn() + "-" + + hop.getEndLine() + ":" + 
hop.getEndColumn() + "]"); + } + + if (hop.getUpdateType().isInPlace()) + sb.append("," + hop.getUpdateType().toString().toLowerCase()); + return sb.toString(); + } + + private static String getNodeToolTip(Hop hop) { + StringBuilder sb = new StringBuilder(); + if (hop.getExecType() != null) { + sb.append(hop.getExecType().name()); + } + sb.append("[" + hop.getDim1() + " X " + hop.getDim2() + "], nnz=" + hop.getNnz()); + sb.append(", mem= [in="); + sb.append(showMem(hop.getInputMemEstimate(), false)); + sb.append(", inter="); + sb.append(showMem(hop.getIntermediateMemEstimate(), false)); + sb.append(", out="); + sb.append(showMem(hop.getOutputMemEstimate(), false)); + sb.append(" -> "); + sb.append(showMem(hop.getMemEstimate(), true)); + sb.append("]"); + return sb.toString(); + } + + private static String getNodeShape(Hop hop) { + String shape = "octagon"; + if (hop.getExecType() != null) { + switch (hop.getExecType()) { + case CP: + shape = "ellipse"; + break; + case SPARK: + shape = "box"; + break; + case GPU: + shape = "trapezium"; + break; + case MR: + shape = "parallelogram"; + break; + default: + shape = "octagon"; + break; + } + } + return shape; + } + + private static String getNodeColor(Hop hop) { + if (hop instanceof DataOp) { + DataOp dOp = (DataOp) hop; + if (dOp.getDataOpType() == DataOpTypes.PERSISTENTREAD || dOp.getDataOpType() == DataOpTypes.TRANSIENTREAD) { + return "wheat2"; + } else if (dOp.getDataOpType() == DataOpTypes.PERSISTENTWRITE + || dOp.getDataOpType() == DataOpTypes.TRANSIENTWRITE) { + return "wheat4"; + } + } else if (hop instanceof AggBinaryOp) { + return "orangered2"; + } else if (hop instanceof BinaryOp) { + return "royalblue2"; + } else if (hop instanceof ReorgOp) { + return "green"; + } else if (hop instanceof UnaryOp) { + return "yellow"; + } + return "black"; + } ////////////// // internal explain CNODE @@ -867,6 +1167,7 @@ public class Explain sb.append( offsetInst ); sb.append( tmp ); + sb.append( '\n' ); } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/python/systemml/mlcontext.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mlcontext.py b/src/main/python/systemml/mlcontext.py index 1e79648..b5a6bf9 100644 --- a/src/main/python/systemml/mlcontext.py +++ b/src/main/python/systemml/mlcontext.py @@ -19,7 +19,11 @@ # #------------------------------------------------------------- -__all__ = ['MLResults', 'MLContext', 'Script', 'dml', 'pydml', 'dmlFromResource', 'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 'dmlFromUrl', 'pydmlFromUrl', '_java2py', 'Matrix'] +# Methods to create Script object +script_factory_methods = [ 'dml', 'pydml', 'dmlFromResource', 'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 'dmlFromUrl', 'pydmlFromUrl' ] +# Utility methods +util_methods = [ 'jvm_stdout', '_java2py', 'getHopDAG' ] +__all__ = ['MLResults', 'MLContext', 'Script', 'Matrix' ] + script_factory_methods + util_methods import os @@ -27,13 +31,16 @@ try: import py4j.java_gateway from py4j.java_gateway import JavaObject from pyspark import SparkContext + from pyspark.conf import SparkConf import pyspark.mllib.common except ImportError: raise ImportError('Unable to import `pyspark`. Hint: Make sure you are running with PySpark.') from .converters import * from .classloader import * +import threading, time +_loadedSystemML = False def _get_spark_context(): """ Internal method to get already initialized SparkContext. 
@@ -44,10 +51,93 @@ def _get_spark_context(): SparkContext """ if SparkContext._active_spark_context is not None: - return SparkContext._active_spark_context + sc = SparkContext._active_spark_context + if not _loadedSystemML: + createJavaObject(sc, 'dummy') + _loadedSystemML = True + return sc else: raise Exception('Expected spark context to be created.') +# This is useful utility class to get the output of the driver JVM from within a Jupyter notebook +# Example usage: +# with jvm_stdout(): +# ml.execute(script) +class jvm_stdout(object): + """ + This is useful utility class to get the output of the driver JVM from within a Jupyter notebook + + Parameters + ---------- + parallel_flush: boolean + Should flush the stdout in parallel + """ + def __init__(self, parallel_flush=False): + self.util = SparkContext._active_spark_context._jvm.org.apache.sysml.api.ml.Utils() + self.parallel_flush = parallel_flush + self.t = threading.Thread(target=self.flush_stdout) + self.stop = False + + def flush_stdout(self): + while not self.stop: + time.sleep(1) # flush stdout every 1 second + str = self.util.flushStdOut() + if str != '': + str = str[:-1] if str.endswith('\n') else str + print(str) + + def __enter__(self): + self.util.startRedirectStdOut() + if self.parallel_flush: + self.t.start() + + def __exit__(self, *args): + if self.parallel_flush: + self.stop = True + self.t.join() + print(self.util.stopRedirectStdOut()) + + +def getHopDAG(ml, script, lines=None, conf=None, apply_rewrites=True, with_subgraph=False): + """ + Compile a DML / PyDML script. + + Parameters + ---------- + ml: MLContext instance + MLContext instance. + + script: Script instance + Script instance defined with the appropriate input and output variables. + + lines: list of integers + Optional: only display the hops that have begin and end line number equals to the given integers. 
+ + conf: SparkConf instance + Optional spark configuration + + apply_rewrites: boolean + If True, perform static rewrites, perform intra-/inter-procedural analysis to propagate size information into functions and apply dynamic rewrites + + with_subgraph: boolean + If False, the dot graph will be created without subgraphs for statement blocks. + + Returns + ------- + hopDAG: string + hop DAG in dot format + """ + if not isinstance(script, Script): + raise ValueError("Expected script to be an instance of Script") + scriptString = script.scriptString + script_java = script.script_java + lines = [ int(x) for x in lines ] if lines is not None else [int(-1)] + sc = _get_spark_context() + if conf is not None: + hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, conf._jconf, apply_rewrites, with_subgraph) + else: + hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, apply_rewrites, with_subgraph) + return hopDAG def dml(scriptString): """ @@ -330,9 +420,9 @@ class Script(object): self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString) elif scriptFormat == "file" and self.scriptType == "pydml": self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString) - elif scriptFormat == "file" and self.scriptType == "dml": + elif isResource and self.scriptType == "dml": self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(scriptString) - elif scriptFormat == "file" and self.scriptType == "pydml": + elif isResource and self.scriptType == "pydml": self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(scriptString) elif scriptFormat == "string" and self.scriptType == "dml": self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString) @@ -605,7 +695,7 @@ class MLContext(object): def 
__repr__(self): return "MLContext" - + def execute(self, script): """ Execute a DML / PyDML script. http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/scala/org/apache/sysml/api/ml/Utils.scala ---------------------------------------------------------------------- diff --git a/src/main/scala/org/apache/sysml/api/ml/Utils.scala b/src/main/scala/org/apache/sysml/api/ml/Utils.scala index da3edf5..a804f64 100644 --- a/src/main/scala/org/apache/sysml/api/ml/Utils.scala +++ b/src/main/scala/org/apache/sysml/api/ml/Utils.scala @@ -18,8 +18,69 @@ */ package org.apache.sysml.api.ml +import org.apache.spark.api.java.JavaPairRDD +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; + +object Utils { + val originalOut = System.out + val originalErr = System.err +} class Utils { def checkIfFileExists(filePath:String):Boolean = { return org.apache.sysml.runtime.util.MapReduceTool.existsFileOnHDFS(filePath) } + + // -------------------------------------------------------------------------------- + // Simple utility function to print the information about our binary blocked format + def getBinaryBlockInfo(binaryBlocks:JavaPairRDD[MatrixIndexes, MatrixBlock]):String = { + val sb = new StringBuilder + var partitionIndex = 0 + for(str <- binaryBlocks.rdd.mapPartitions(binaryBlockIteratorToString(_), true).collect) { + sb.append("-------------------------------------\n") + sb.append("Partition " + partitionIndex + ":\n") + sb.append(str) + partitionIndex = partitionIndex + 1 + } + sb.append("-------------------------------------\n") + return sb.toString() + } + def binaryBlockIteratorToString(it: Iterator[(MatrixIndexes, MatrixBlock)]): Iterator[String] = { + val sb = new StringBuilder + for(entry <- it) { + val mi = entry._1 + val mb = entry._2 + sb.append(mi.toString); + sb.append(" sparse? 
= "); + sb.append(mb.isInSparseFormat()); + if(mb.isUltraSparse) + sb.append(" (ultra-sparse)") + sb.append(", nonzeros = "); + sb.append(mb.getNonZeros); + sb.append(", dimensions = "); + sb.append(mb.getNumRows); + sb.append(" X "); + sb.append(mb.getNumColumns); + sb.append("\n"); + } + List[String](sb.toString).iterator + } + val baos = new java.io.ByteArrayOutputStream() + val baes = new java.io.ByteArrayOutputStream() + def startRedirectStdOut():Unit = { + System.setOut(new java.io.PrintStream(baos)); + System.setErr(new java.io.PrintStream(baes)); + } + def flushStdOut():String = { + val ret = baos.toString() + baes.toString() + baos.reset(); baes.reset() + return ret + } + def stopRedirectStdOut():String = { + val ret = baos.toString() + baes.toString() + System.setOut(Utils.originalOut) + System.setErr(Utils.originalErr) + return ret + } + // -------------------------------------------------------------------------------- } \ No newline at end of file
