Repository: systemml
Updated Branches:
  refs/heads/master 2c9694dec -> ac1cf093a


[SYSTEMML-1808] [SYSTEMML-1658] Visualize Hop DAG for explaining the optimizer

- Also added a utility to print Java output in the notebook.
- Fixed a bug in dmlFromResource.

Closes #596.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ac1cf093
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ac1cf093
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ac1cf093

Branch: refs/heads/master
Commit: ac1cf093ad0b47cb6a0f0d48c4deb276b4ae1fa6
Parents: 2c9694d
Author: Niketan Pansare <[email protected]>
Authored: Thu Aug 3 09:06:24 2017 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Thu Aug 3 10:06:24 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/api/mlcontext/MLContext.java   |  25 +-
 .../sysml/api/mlcontext/MLContextUtil.java      | 108 +++++++
 .../sysml/api/mlcontext/ScriptExecutor.java     |  46 ++-
 .../context/SparkExecutionContext.java          |   2 +-
 .../sysml/runtime/instructions/Instruction.java |  20 ++
 .../java/org/apache/sysml/utils/Explain.java    | 301 +++++++++++++++++++
 src/main/python/systemml/mlcontext.py           | 100 +++++-
 .../scala/org/apache/sysml/api/ml/Utils.scala   |  61 ++++
 8 files changed, 646 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
index b35faa6..f74d593 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -46,7 +46,6 @@ import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.utils.Explain.ExplainType;
 import org.apache.sysml.utils.MLContextProxy;
-
 /**
  * The MLContext API offers programmatic access to SystemML on Spark from
  * languages such as Scala, Java, and Python.
@@ -287,6 +286,8 @@ public class MLContext {
        public void resetConfig() {
                MLContextUtil.setDefaultConfig();
        }
+       
+       
 
        /**
         * Set configuration property, such as
@@ -305,7 +306,8 @@ public class MLContext {
                        throw new MLContextException(e);
                }
        }
-
+       
+       
        /**
         * Execute a DML or PYDML Script.
         *
@@ -357,6 +359,16 @@ public class MLContext {
                        throw new MLContextException("Exception when executing 
script", e);
                }
        }
+       
+       /**
+        * Sets the script that is being executed. MLContextUtil.getHopDAG also
+        * calls this setter right before it compiles a script, so the stored
+        * script may be one that is only compiled and never run.
+        * 
+        * @param executionScript
+        *            script that is being executed (or compiled)
+        */
+       public void setExecutionScript(Script executionScript) {
+               this.executionScript = executionScript;
+       }
 
        /**
         * Set SystemML configuration based on a configuration file.
@@ -489,6 +501,15 @@ public class MLContext {
        }
 
        /**
+        * Whether or not the "force" GPU mode is enabled. This simply reports
+        * the current value of this context's {@code forceGPU} field.
+        *
+        * @return true if enabled, false otherwise
+        */
+       public boolean isForceGPU() {
+               return this.forceGPU;
+       }
+       
+       /**
         * Used internally by MLContextProxy.
         *
         */

http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 2c9566c..51d38a5 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -23,6 +23,7 @@ import java.io.File;
 import java.io.FileNotFoundException;
 import java.net.URL;
 import java.util.ArrayList;
+import java.util.Date;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -35,6 +36,7 @@ import javax.xml.parsers.DocumentBuilderFactory;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.text.WordUtils;
+import org.apache.spark.SparkConf;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
@@ -52,8 +54,11 @@ import org.apache.sysml.conf.CompilerConfig;
 import org.apache.sysml.conf.CompilerConfig.ConfigType;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.conf.DMLConfig;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.parser.LanguageException;
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
@@ -63,6 +68,8 @@ import org.apache.sysml.runtime.controlprogram.ProgramBlock;
 import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
 import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.instructions.cp.BooleanObject;
 import org.apache.sysml.runtime.instructions.cp.Data;
@@ -73,6 +80,7 @@ import 
org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysml.runtime.matrix.data.FrameBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.utils.Explain;
 import org.apache.sysml.utils.MLContextProxy;
 import org.w3c.dom.Document;
 import org.w3c.dom.Node;
@@ -83,6 +91,106 @@ import org.w3c.dom.NodeList;
  *
  */
 public final class MLContextUtil {
+       
+       /**
+        * Compile a DML or PYDML script and return its HOP DAG in dot format,
+        * using the Spark configuration that is currently in effect. This is a
+        * convenience overload that delegates to the six-argument variant with
+        * no replacement Spark configuration.
+        *
+        * @param mlCtx
+        *            MLContext object.
+        * @param script
+        *            The DML or PYDML Script object to compile.
+        * @param lines
+        *            Line filter: only hops whose begin line and end line are
+        *            both equal to one of the given integers are displayed.
+        * @param performHOPRewrites
+        *            whether the HOP rewrite step should run during compilation
+        *            (static rewrites, intra-/inter-procedural analysis to
+        *            propagate size information into functions, and dynamic
+        *            rewrites)
+        * @param withSubgraph
+        *            If false, the dot graph will be created without subgraphs
+        *            for statement blocks.
+        * @return hop DAG in dot format
+        * @throws LanguageException
+        *             if error occurs
+        * @throws DMLRuntimeException
+        *             if error occurs
+        * @throws HopsException
+        *             if error occurs
+        */
+       public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines,
+                       boolean performHOPRewrites, boolean withSubgraph)
+                       throws HopsException, DMLRuntimeException, LanguageException {
+               // no replacement configuration: keep whatever Spark settings are active
+               final SparkConf noConfOverride = null;
+               return getHopDAG(mlCtx, script, lines, noConfOverride, performHOPRewrites, withSubgraph);
+       }
+
+       /**
+        * Compile a DML or PYDML script (without executing it) and return its
+        * HOP DAG in dot format.
+        *
+        * If {@code newConf} is given, the SystemML Spark cluster configuration
+        * and the local max-memory budget are switched to that configuration
+        * for the duration of the compilation and restored afterwards, even if
+        * compilation fails.
+        *
+        * @param mlCtx
+        *            MLContext object. Its execution type, GPU flags, init flag
+        *            and symbol-table flag are copied to the ScriptExecutor used
+        *            for compilation; if init-before-execution is set, it is
+        *            cleared on the context.
+        * @param script
+        *            The DML or PYDML Script object to compile. If it has no
+        *            name, the current time in milliseconds is assigned as name.
+        * @param lines
+        *            Line filter: only hops whose begin line and end line are
+        *            both equal to one of the given integers are displayed.
+        *            {@code null}, an empty list, or the single-element list
+        *            {@code [-1]} disable the filter.
+        * @param newConf
+        *            Spark Configuration to analyze the script under, or
+        *            {@code null} to keep the current one. When non-null it is
+        *            queried for {@code spark.driver.memory}.
+        * @param performHOPRewrites
+        *            whether the HOP rewrite step should run during compilation
+        *            (static rewrites, intra-/inter-procedural analysis to
+        *            propagate size information into functions, and dynamic
+        *            rewrites)
+        * @param withSubgraph
+        *            If false, the dot graph will be created without subgraphs
+        *            for statement blocks.
+        * @return hop DAG in dot format
+        * @throws LanguageException
+        *             if error occurs
+        * @throws DMLRuntimeException
+        *             if error occurs
+        * @throws HopsException
+        *             if error occurs
+        * @throws MLContextException
+        *             (unchecked) wrapping any RuntimeException raised while
+        *             compiling
+        */
+       public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, SparkConf newConf,
+                       boolean performHOPRewrites, boolean withSubgraph)
+                       throws HopsException, DMLRuntimeException, LanguageException {
+               // remember the current settings so that the finally-block can restore them
+               SparkConf oldConf = mlCtx.getSparkSession().sparkContext().getConf();
+               SparkExecutionContext.SparkClusterConfig systemmlConf = SparkExecutionContext.getSparkClusterConfig();
+               long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory();
+               try {
+                       if (newConf != null) {
+                               systemmlConf.analyzeSparkConfiguation(newConf);
+                               InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory"));
+                       }
+                       ScriptExecutor scriptExecutor = new ScriptExecutor();
+                       scriptExecutor.setExecutionType(mlCtx.getExecutionType());
+                       scriptExecutor.setGPU(mlCtx.isGPU());
+                       scriptExecutor.setForceGPU(mlCtx.isForceGPU());
+                       scriptExecutor.setInit(mlCtx.isInitBeforeExecution());
+                       if (mlCtx.isInitBeforeExecution()) {
+                               mlCtx.setInitBeforeExecution(false);
+                       }
+                       scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable());
+
+                       // unnamed scripts get the current timestamp as their name
+                       if ((script.getName() == null) || script.getName().isEmpty()) {
+                               script.setName(Long.toString(new Date().getTime()));
+                       }
+
+                       mlCtx.setExecutionScript(script);
+                       scriptExecutor.compile(script, performHOPRewrites);
+                       Explain.reset();
+                       // A null list, an empty list and the sentinel list [-1] all mean
+                       // "no line filter". The [-1] form exists to deal with potential
+                       // Py4J issues.
+                       boolean noFilter = lines == null || lines.isEmpty()
+                                       || (lines.size() == 1 && lines.get(0) == -1);
+                       ArrayList<Integer> lineFilter = noFilter ? new ArrayList<Integer>() : lines;
+                       return Explain.getHopDAG(scriptExecutor.dmlProgram, lineFilter, withSubgraph);
+               } catch (RuntimeException e) {
+                       throw new MLContextException("Exception when compiling script", e);
+               } finally {
+                       if (newConf != null) {
+                               systemmlConf.analyzeSparkConfiguation(oldConf);
+                               InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory);
+                       }
+               }
+       }
 
        /**
         * Basic data types supported by the MLContext API

http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java 
b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index 467e94e..7e78891 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -263,9 +263,16 @@ public class ScriptExecutor {
                DMLScript.USE_ACCELERATOR = oldGPU;
                DMLScript.STATISTICS_COUNT = 
DMLOptions.defaultOptions.statsCount;
        }
-
+       
+       /**
+        * Compile a DML or PYDML script with HOP rewrites enabled. This is a
+        * shorthand for {@code compile(script, true)}.
+        *
+        * @param script
+        *            the DML or PYDML script to compile
+        */
+       public void compile(Script script) {
+               compile(script, true);
+       }
+       
        /**
-        * Execute a DML or PYDML script. This is broken down into the following
+        * Compile a DML or PYDML script. This will help analysis of DML 
programs
+        * that have dynamic recompilation flag set to false without actually 
executing it. 
+        * 
+        * This is broken down into the following
         * primary methods:
         *
         * <ol>
@@ -283,16 +290,14 @@ public class ScriptExecutor {
         * <li>{@link #countCompiledMRJobsAndSparkInstructions()}</li>
         * <li>{@link #initializeCachingAndScratchSpace()}</li>
         * <li>{@link #cleanupRuntimeProgram()}</li>
-        * <li>{@link #createAndInitializeExecutionContext()}</li>
-        * <li>{@link #executeRuntimeProgram()}</li>
-        * <li>{@link #cleanupAfterExecution()}</li>
         * </ol>
         *
         * @param script
-        *            the DML or PYDML script to execute
-        * @return the results as a MLResults object
+        *            the DML or PYDML script to compile
+        * @param performHOPRewrites
+        *            should perform static rewrites, perform 
intra-/inter-procedural analysis to propagate size information into functions 
and apply dynamic rewrites
         */
-       public MLResults execute(Script script) {
+       public void compile(Script script, boolean performHOPRewrites) {
 
                // main steps in script execution
                setup(script);
@@ -303,7 +308,8 @@ public class ScriptExecutor {
                liveVariableAnalysis();
                validateScript();
                constructHops();
-               rewriteHops();
+               if(performHOPRewrites)
+                       rewriteHops();
                rewritePersistentReadsAndWrites();
                constructLops();
                generateRuntimeProgram();
@@ -315,6 +321,28 @@ public class ScriptExecutor {
                if (statistics) {
                        Statistics.stopCompileTimer();
                }
+       }
+
+
+       /**
+        * Execute a DML or PYDML script. This is broken down into the following
+        * primary methods:
+        *
+        * <ol>
+        * <li>{@link #compile(Script)}</li>
+        * <li>{@link #createAndInitializeExecutionContext()}</li>
+        * <li>{@link #executeRuntimeProgram()}</li>
+        * <li>{@link #cleanupAfterExecution()}</li>
+        * </ol>
+        *
+        * @param script
+        *            the DML or PYDML script to execute
+        * @return the results as a MLResults object
+        */
+       public MLResults execute(Script script) {
+
+               // main steps in script execution
+               compile(script);
 
                try {
                        createAndInitializeExecutionContext();

http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 6f2f766..d1ff7d8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1378,7 +1378,7 @@ public class SparkExecutionContext extends 
ExecutionContext
         * degree of parallelism. This configuration abstracts legacy (< Spark 
1.6) and current
         * configurations and provides a unified view.
         */
-       private static class SparkClusterConfig
+       public static class SparkClusterConfig
        {
                //broadcasts are stored in mem-and-disk in data space, this 
config
                //defines the fraction of data space to be used as broadcast 
budget

http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
index 6db8c7f..374f81c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
@@ -63,6 +63,26 @@ public abstract class Instruction
        protected int beginCol = -1; 
        protected int endCol = -1;
        
+       /**
+        * Returns the value of this instruction's {@code filename} field.
+        * NOTE(review): presumably the script file the instruction was
+        * generated from - confirm against where the field is set.
+        *
+        * @return the stored filename
+        */
+       public String getFilename() {
+               return filename;
+       }
+
+       /**
+        * Returns the value of this instruction's {@code beginLine} field.
+        * NOTE(review): presumably the first script line covered by the
+        * instruction - confirm against where the field is set.
+        *
+        * @return the stored begin line
+        */
+       public int getBeginLine() {
+               return beginLine;
+       }
+
+       /**
+        * Returns the value of this instruction's {@code endLine} field.
+        * NOTE(review): presumably the last script line covered by the
+        * instruction - confirm against where the field is set.
+        *
+        * @return the stored end line
+        */
+       public int getEndLine() {
+               return endLine;
+       }
+
+       /**
+        * Returns the value of this instruction's {@code beginCol} field,
+        * which is initialized to -1.
+        *
+        * @return the stored begin column, or -1 if it was never assigned
+        */
+       public int getBeginColumn() {
+               return beginCol;
+       }
+
+       /**
+        * Returns the value of this instruction's {@code endCol} field,
+        * which is initialized to -1.
+        *
+        * @return the stored end column, or -1 if it was never assigned
+        */
+       public int getEndColumn() {
+               return endCol;
+       }
+
        public void setType (INSTRUCTION_TYPE tp ) {
                type = tp;
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java 
b/src/main/java/org/apache/sysml/utils/Explain.java
index a2e843a..01b59f7 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -26,10 +26,16 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Map.Entry;
 
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
@@ -266,6 +272,50 @@ public class Explain
        
                return sb.toString();
        }
+       
+       /**
+        * Render the HOP DAGs of a whole DML program as one Graphviz digraph.
+        *
+        * The output consists of the edges (optionally grouped into clusters)
+        * of all non-external functions, followed by those of the main
+        * program, followed by all node declarations and a bottom-to-top rank
+        * direction.
+        *
+        * @param prog
+        *            the program whose statement blocks are traversed
+        * @param lines
+        *            line filter passed down to the per-hop rendering; an empty
+        *            list selects every hop
+        * @param withSubgraph
+        *            if true, statement blocks are wrapped into dot subgraphs
+        * @return the dot source text
+        * @throws HopsException
+        *             if error occurs
+        * @throws DMLRuntimeException
+        *             if error occurs
+        * @throws LanguageException
+        *             if error occurs
+        */
+       public static String getHopDAG(DMLProgram prog, ArrayList<Integer> lines, boolean withSubgraph)
+                       throws HopsException, DMLRuntimeException, LanguageException {
+               StringBuilder dot = new StringBuilder();
+               // node declarations are collected separately and emitted after all edges
+               StringBuilder nodeDecls = new StringBuilder();
+
+               // create header
+               dot.append("digraph {");
+
+               // Explain functions (if exists)
+               if (prog.hasFunctionStatementBlocks()) {
+
+                       // the function call graph is currently not shown:
+                       // FunctionCallGraph fgraph = new FunctionCallGraph(prog);
+                       // sb.append(explainFunctionCallGraph(fgraph, new HashSet<String>(),
+                       // null, 3));
+
+                       // show individual functions
+                       for (String namespace : prog.getNamespaces().keySet()) {
+                               for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()) {
+                                       FunctionStatementBlock fsb = prog.getFunctionStatementBlock(namespace, fname);
+                                       FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
+                                       String fkey = DMLProgram.constructFunctionKey(namespace, fname);
+
+                                       // external functions are skipped
+                                       if (fstmt instanceof ExternalFunctionStatement)
+                                               continue;
+
+                                       addSubGraphHeader(dot, withSubgraph);
+                                       for (StatementBlock bodyBlock : fstmt.getBody())
+                                               dot.append(getHopDAG(bodyBlock, nodeDecls, lines, withSubgraph));
+                                       addSubGraphFooter(dot, withSubgraph,
+                                                       "FUNCTION " + fkey + " recompile=" + fsb.isRecompileOnce() + "\n");
+                               }
+                       }
+               }
+
+               // Explain main program
+               for (StatementBlock mainBlock : prog.getStatementBlocks())
+                       dot.append(getHopDAG(mainBlock, nodeDecls, lines, withSubgraph));
+
+               dot.append(nodeDecls).append("rankdir = \"BT\"\n").append("}\n");
+               return dot.toString();
+       }
 
        public static String explain( Program rtprog ) throws HopsException {
                return explain(rtprog, null);
@@ -466,6 +516,128 @@ public class Explain
        //////////////
        // internal explain HOPS
 
+       // Running counter used to name the dot subgraphs "cluster_<n>".
+       // NOTE(review): plain static state incremented without synchronization;
+       // confirm that HOP DAG generation is never invoked concurrently.
+       private static int clusterID = 0;
+
+       /**
+        * Restart the subgraph numbering at zero, so that the next generated
+        * graph begins with {@code cluster_0}.
+        */
+       public static void reset() {
+               clusterID = 0;
+       }
+
+       /**
+        * Open a new dot subgraph named {@code cluster_<n>}, where n is the
+        * next unused cluster id. Does nothing when subgraphs are disabled.
+        *
+        * @param builder
+        *            receives the opening line of the subgraph
+        * @param withSubgraph
+        *            if false, nothing is appended and no id is consumed
+        */
+       private static void addSubGraphHeader(StringBuilder builder, boolean withSubgraph) {
+               if (!withSubgraph)
+                       return;
+               int id = clusterID++;
+               builder.append("subgraph cluster_").append(id).append(" {\n");
+       }
+
+       /**
+        * Close the dot subgraph that is currently open: emit its label
+        * attribute followed by the closing brace. Does nothing when subgraphs
+        * are disabled.
+        *
+        * @param builder
+        *            receives the label line and the closing brace
+        * @param withSubgraph
+        *            if false, nothing is appended
+        * @param label
+        *            text placed between double quotes as the subgraph label;
+        *            it is inserted as is, without escaping
+        */
+       private static void addSubGraphFooter(StringBuilder builder, boolean withSubgraph, String label) {
+               if (!withSubgraph)
+                       return;
+               builder.append("label = \"").append(label).append("\";\n");
+               builder.append("}\n");
+       }
+
+       /**
+        * Recursively render one statement block, including all blocks nested
+        * in it, as dot text.
+        *
+        * Control-flow blocks (while, if/else, for/parfor, function) become one
+        * subgraph each when {@code withSubgraph} is true; a generic block
+        * becomes a subgraph only if it requires recompilation. Every subgraph
+        * that is opened is closed exactly once.
+        *
+        * @param sb
+        *            the statement block to render
+        * @param nodes
+        *            accumulator for node declarations; appended to by the
+        *            per-hop rendering
+        * @param lines
+        *            line filter passed down to the per-hop rendering
+        * @param withSubgraph
+        *            if false, no subgraph headers or footers are emitted
+        * @return the edges of all hops below this block, wrapped into
+        *         subgraphs where applicable
+        * @throws HopsException
+        *             if error occurs
+        * @throws DMLRuntimeException
+        *             if error occurs
+        */
+       private static StringBuilder getHopDAG(StatementBlock sb, StringBuilder nodes, ArrayList<Integer> lines,
+                       boolean withSubgraph) throws HopsException, DMLRuntimeException {
+               StringBuilder builder = new StringBuilder();
+
+               if (sb instanceof WhileStatementBlock) {
+                       addSubGraphHeader(builder, withSubgraph);
+
+                       WhileStatementBlock wsb = (WhileStatementBlock) sb;
+                       String label = "WHILE (lines " + wsb.getBeginLine() + "-" + wsb.getEndLine() + ")";
+                       if (!wsb.getUpdateInPlaceVars().isEmpty())
+                               label += " in-place=" + wsb.getUpdateInPlaceVars().toString();
+                       // TODO: Don't show predicate hops for now
+                       // builder.append(explainHop(wsb.getPredicateHops()));
+
+                       WhileStatement ws = (WhileStatement) sb.getStatement(0);
+                       for (StatementBlock current : ws.getBody())
+                               builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+
+                       addSubGraphFooter(builder, withSubgraph, label);
+               } else if (sb instanceof IfStatementBlock) {
+                       addSubGraphHeader(builder, withSubgraph);
+                       IfStatementBlock ifsb = (IfStatementBlock) sb;
+                       String label = "IF (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")";
+                       // TODO: Don't show predicate hops for now
+                       // builder.append(explainHop(ifsb.getPredicateHops(), level+1));
+
+                       IfStatement ifs = (IfStatement) sb.getStatement(0);
+                       for (StatementBlock current : ifs.getIfBody())
+                               builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+                       // Close the IF subgraph once, after all of its body blocks. Emitting
+                       // the footer per body block would produce unbalanced braces for any
+                       // if-body that does not consist of exactly one block.
+                       addSubGraphFooter(builder, withSubgraph, label);
+
+                       if (!ifs.getElseBody().isEmpty()) {
+                               addSubGraphHeader(builder, withSubgraph);
+                               label = "ELSE (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")";
+
+                               for (StatementBlock current : ifs.getElseBody())
+                                       builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+                               addSubGraphFooter(builder, withSubgraph, label);
+                       }
+               } else if (sb instanceof ForStatementBlock) {
+                       ForStatementBlock fsb = (ForStatementBlock) sb;
+                       addSubGraphHeader(builder, withSubgraph);
+                       // parfor blocks are checked via instanceof on the same object
+                       String keyword = (sb instanceof ParForStatementBlock) ? "PARFOR" : "FOR";
+                       String label = keyword + " (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ")";
+                       if (!fsb.getUpdateInPlaceVars().isEmpty())
+                               label += " in-place=" + fsb.getUpdateInPlaceVars().toString();
+                       // TODO: Don't show predicate hops for now
+                       // if (fsb.getFromHops() != null)
+                       // builder.append(explainHop(fsb.getFromHops(), level+1));
+                       // if (fsb.getToHops() != null)
+                       // builder.append(explainHop(fsb.getToHops(), level+1));
+                       // if (fsb.getIncrementHops() != null)
+                       // builder.append(explainHop(fsb.getIncrementHops(), level+1));
+
+                       ForStatement fs = (ForStatement) sb.getStatement(0);
+                       for (StatementBlock current : fs.getBody())
+                               builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+                       addSubGraphFooter(builder, withSubgraph, label);
+
+               } else if (sb instanceof FunctionStatementBlock) {
+                       FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
+                       addSubGraphHeader(builder, withSubgraph);
+                       String label = "Function (lines " + fstmt.getBeginLine() + "-" + fstmt.getEndLine() + ")";
+                       for (StatementBlock current : fstmt.getBody())
+                               builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+                       addSubGraphFooter(builder, withSubgraph, label);
+               } else {
+                       // For generic StatementBlock: only blocks that require
+                       // recompilation get a (grey) subgraph of their own
+                       if (sb.requiresRecompilation()) {
+                               addSubGraphHeader(builder, withSubgraph);
+                       }
+                       ArrayList<Hop> hopsDAG = sb.get_hops();
+                       if (hopsDAG != null && !hopsDAG.isEmpty()) {
+                               // clear visit flags before and after so the traversal neither
+                               // depends on nor leaks visit state
+                               Hop.resetVisitStatus(hopsDAG);
+                               for (Hop hop : hopsDAG)
+                                       builder.append(getHopDAG(hop, nodes, lines, withSubgraph));
+                               Hop.resetVisitStatus(hopsDAG);
+                       }
+
+                       if (sb.requiresRecompilation()) {
+                               // NOTE(review): these two attribute lines are emitted even when
+                               // withSubgraph is false, in which case they land outside any
+                               // subgraph - confirm whether that is intended.
+                               builder.append("style=filled;\n");
+                               builder.append("color=lightgrey;\n");
+                               String label = "(lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ") [recompile="
+                                               + sb.requiresRecompilation() + "]";
+                               addSubGraphFooter(builder, withSubgraph, label);
+                       }
+               }
+               return builder;
+       }
+
        private static String explainStatementBlock(StatementBlock sb, int 
level) 
                throws HopsException, DMLRuntimeException 
        {
@@ -636,6 +808,134 @@ public class Explain
                
                return sb.toString();
        }
+       
+       private static boolean isInRange(Hop hop, ArrayList<Integer> lines) {
+               boolean isInRange = lines.size() == 0 ? true : false;
+               for (int lineNum : lines) {
+                       if (hop.getBeginLine() == lineNum && lineNum == 
hop.getEndLine()) {
+                               return true;
+                       }
+               }
+               return isInRange;
+       }
+
+       private static StringBuilder getHopDAG(Hop hop, StringBuilder nodes, 
ArrayList<Integer> lines, boolean withSubgraph)
+                       throws DMLRuntimeException {
+               StringBuilder sb = new StringBuilder();
+               if (hop.isVisited() || (!SHOW_LITERAL_HOPS && hop instanceof 
LiteralOp))
+                       return sb;
+
+               for (Hop input : hop.getInput()) {
+                       if ((SHOW_LITERAL_HOPS || !(input instanceof 
LiteralOp)) && isInRange(hop, lines)) {
+                               String edgeLabel = 
showMem(input.getOutputMemEstimate(), true);
+                               sb.append("h" + input.getHopID() + " -> h" + 
hop.getHopID() + " [label=\"" + edgeLabel + "\"];\n");
+                       }
+               }
+               for (Hop input : hop.getInput())
+                       sb.append(getHopDAG(input, nodes, lines, withSubgraph));
+
+               if (isInRange(hop, lines)) {
+                       nodes.append("h" + hop.getHopID() + "[label=\"" + 
getNodeLabel(hop) + "\", " + "shape=\""
+                                       + getNodeShape(hop) + "\", color=\"" + 
getNodeColor(hop) + "\", tooltip=\"" + getNodeToolTip(hop)
+                                       + "\"];\n");
+               }
+               hop.setVisited();
+
+               return sb;
+       }
+
+       /**
+        * Builds the text shown inside a dot node: the hop's op string, the matrix
+        * multiplication method for AggBinaryOp hops (when set), optional reblock/checkpoint
+        * markers, the source position and, for in-place updates, the lower-cased update type.
+        *
+        * @param hop hop to describe
+        * @return label text for the node
+        */
+       private static String getNodeLabel(Hop hop) {
+               StringBuilder label = new StringBuilder();
+               label.append(hop.getOpString());
+
+               if (hop instanceof AggBinaryOp) {
+                       AggBinaryOp mm = (AggBinaryOp) hop;
+                       if (mm.getMMultMethod() != null)
+                               label.append(' ').append(mm.getMMultMethod().name()).append(' ');
+               }
+
+               // data flow properties
+               if (SHOW_DATA_FLOW_PROPERTIES) {
+                       boolean rblk = hop.requiresReblock();
+                       boolean chkpt = hop.requiresCheckpoint();
+                       if (rblk && chkpt)
+                               label.append(", rblk,chkpt");
+                       else if (rblk)
+                               label.append(", rblk");
+                       else if (chkpt)
+                               label.append(", chkpt");
+               }
+
+               // source position, prefixed with the file name when one is known
+               label.append('[');
+               if (hop.getFilename() != null)
+                       label.append(hop.getFilename()).append(' ');
+               label.append(hop.getBeginLine()).append(':').append(hop.getBeginColumn()).append('-')
+                               .append(hop.getEndLine()).append(':').append(hop.getEndColumn()).append(']');
+
+               if (hop.getUpdateType().isInPlace())
+                       label.append(',').append(hop.getUpdateType().toString().toLowerCase());
+               return label.toString();
+       }
+
+       /**
+        * Builds the tooltip of a dot node: execution type (if set), dimensions,
+        * number of non-zeros and the input/intermediate/output/total memory estimates
+        * formatted by showMem.
+        *
+        * @param hop hop to describe
+        * @return tooltip text for the node
+        */
+       private static String getNodeToolTip(Hop hop) {
+               String exec = (hop.getExecType() != null) ? hop.getExecType().name() : "";
+               String dims = "[" + hop.getDim1() + " X " + hop.getDim2() + "], nnz=" + hop.getNnz();
+               String mem = ", mem= [in=" + showMem(hop.getInputMemEstimate(), false)
+                               + ", inter=" + showMem(hop.getIntermediateMemEstimate(), false)
+                               + ", out=" + showMem(hop.getOutputMemEstimate(), false)
+                               + " -> " + showMem(hop.getMemEstimate(), true)
+                               + "]";
+               return exec + dims + mem;
+       }
+
+       /**
+        * Maps the execution type of a hop to a dot node shape.
+        * A missing execution type and any type without its own case yield "octagon".
+        *
+        * @param hop hop whose execution type selects the shape
+        * @return dot shape name
+        */
+       private static String getNodeShape(Hop hop) {
+               if (hop.getExecType() == null)
+                       return "octagon";
+               switch (hop.getExecType()) {
+               case CP:
+                       return "ellipse";
+               case SPARK:
+                       return "box";
+               case GPU:
+                       return "trapezium";
+               case MR:
+                       return "parallelogram";
+               default:
+                       return "octagon";
+               }
+       }
+
+       /**
+        * Chooses the dot color of a node from the kind of hop: reads and writes
+        * (DataOp) get two wheat tones, AggBinaryOp, BinaryOp, ReorgOp and UnaryOp
+        * each get their own color, everything else is black.
+        *
+        * @param hop hop to color
+        * @return dot color name
+        */
+       private static String getNodeColor(Hop hop) {
+               if (hop instanceof DataOp) {
+                       DataOpTypes type = ((DataOp) hop).getDataOpType();
+                       if (type == DataOpTypes.PERSISTENTREAD || type == DataOpTypes.TRANSIENTREAD)
+                               return "wheat2";
+                       if (type == DataOpTypes.PERSISTENTWRITE || type == DataOpTypes.TRANSIENTWRITE)
+                               return "wheat4";
+                       // any other data op type keeps the default color
+                       return "black";
+               }
+               if (hop instanceof AggBinaryOp)
+                       return "orangered2";
+               if (hop instanceof BinaryOp)
+                       return "royalblue2";
+               if (hop instanceof ReorgOp)
+                       return "green";
+               if (hop instanceof UnaryOp)
+                       return "yellow";
+               return "black";
+       }
 
        //////////////
        // internal explain CNODE
@@ -867,6 +1167,7 @@ public class Explain
                        
                        sb.append( offsetInst );
                        sb.append( tmp );
+                       
                        sb.append( '\n' );
                }
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/python/systemml/mlcontext.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mlcontext.py 
b/src/main/python/systemml/mlcontext.py
index 1e79648..b5a6bf9 100644
--- a/src/main/python/systemml/mlcontext.py
+++ b/src/main/python/systemml/mlcontext.py
@@ -19,7 +19,11 @@
 #
 #-------------------------------------------------------------
 
-__all__ = ['MLResults', 'MLContext', 'Script', 'dml', 'pydml', 
'dmlFromResource', 'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 
'dmlFromUrl', 'pydmlFromUrl',  '_java2py', 'Matrix']
+# Methods to create Script object
+script_factory_methods = [ 'dml', 'pydml', 'dmlFromResource', 
'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 'dmlFromUrl', 
'pydmlFromUrl' ]
+# Utility methods
+util_methods = [ 'jvm_stdout', '_java2py',  'getHopDAG' ]
+__all__ = ['MLResults', 'MLContext', 'Script', 'Matrix' ] + 
script_factory_methods + util_methods
 
 import os
 
@@ -27,13 +31,16 @@ try:
     import py4j.java_gateway
     from py4j.java_gateway import JavaObject
     from pyspark import SparkContext
+    from pyspark.conf import SparkConf
     import pyspark.mllib.common
 except ImportError:
     raise ImportError('Unable to import `pyspark`. Hint: Make sure you are 
running with PySpark.')
 
 from .converters import *
 from .classloader import *
+import threading, time
 
+_loadedSystemML = False
 def _get_spark_context():
     """
     Internal method to get already initialized SparkContext.
@@ -44,10 +51,93 @@ def _get_spark_context():
         SparkContext
     """
     if SparkContext._active_spark_context is not None:
-        return SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context
+        if not _loadedSystemML:
+            createJavaObject(sc, 'dummy')
+            _loadedSystemML = True
+        return sc
     else:
         raise Exception('Expected spark context to be created.')
 
+# This is a useful utility class to get the output of the driver JVM from within a Jupyter notebook
+# Example usage:
+# with jvm_stdout():
+#    ml.execute(script)
+class jvm_stdout(object):
+    """
+    This is useful utility class to get the output of the driver JVM from 
within a Jupyter notebook
+
+    Parameters
+    ----------
+    parallel_flush: boolean
+        Should flush the stdout in parallel
+    """
+    def __init__(self, parallel_flush=False):
+        self.util = 
SparkContext._active_spark_context._jvm.org.apache.sysml.api.ml.Utils()
+        self.parallel_flush = parallel_flush
+        self.t = threading.Thread(target=self.flush_stdout)
+        self.stop = False
+        
+    def flush_stdout(self):
+        while not self.stop: 
+            time.sleep(1) # flush stdout every 1 second
+            str = self.util.flushStdOut()
+            if str != '':
+                str = str[:-1] if str.endswith('\n') else str
+                print(str)
+    
+    def __enter__(self):
+        self.util.startRedirectStdOut()
+        if self.parallel_flush:
+            self.t.start()
+
+    def __exit__(self, *args):
+        if self.parallel_flush:
+            self.stop = True
+            self.t.join()
+        print(self.util.stopRedirectStdOut())
+        
+
+def getHopDAG(ml, script, lines=None, conf=None, apply_rewrites=True, 
with_subgraph=False):
+    """
+    Compile a DML / PyDML script.
+
+    Parameters
+    ----------
+    ml: MLContext instance
+        MLContext instance.
+        
+    script: Script instance
+        Script instance defined with the appropriate input and output 
variables.
+    
+    lines: list of integers
+        Optional: only display the hops that have begin and end line number 
equals to the given integers.
+    
+    conf: SparkConf instance
+        Optional spark configuration
+        
+    apply_rewrites: boolean
+        If True, perform static rewrites, perform intra-/inter-procedural 
analysis to propagate size information into functions and apply dynamic rewrites
+    
+    with_subgraph: boolean
+        If False, the dot graph will be created without subgraphs for 
statement blocks. 
+    
+    Returns
+    -------
+    hopDAG: string
+        hop DAG in dot format 
+    """
+    if not isinstance(script, Script):
+        raise ValueError("Expected script to be an instance of Script")
+    scriptString = script.scriptString
+    script_java = script.script_java
+    lines = [ int(x) for x in lines ] if lines is not None else [int(-1)]
+    sc = _get_spark_context()
+    if conf is not None:
+        hopDAG = 
sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, 
script_java, lines, conf._jconf, apply_rewrites, with_subgraph)
+    else:
+        hopDAG = 
sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, 
script_java, lines, apply_rewrites, with_subgraph)
+    return hopDAG
 
 def dml(scriptString):
     """
@@ -330,9 +420,9 @@ class Script(object):
                 self.script_java = 
self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString)
             elif scriptFormat == "file" and self.scriptType == "pydml":
                 self.script_java = 
self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString)
-            elif scriptFormat == "file" and self.scriptType == "dml":
+            elif isResource and self.scriptType == "dml":
                 self.script_java = 
self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(scriptString)
-            elif scriptFormat == "file" and self.scriptType == "pydml":
+            elif isResource and self.scriptType == "pydml":
                 self.script_java = 
self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(scriptString)
             elif scriptFormat == "string" and self.scriptType == "dml":
                 self.script_java = 
self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString)
@@ -605,7 +695,7 @@ class MLContext(object):
 
     def __repr__(self):
         return "MLContext"
-    
+        
     def execute(self, script):
         """
         Execute a DML / PyDML script.

http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/scala/org/apache/sysml/api/ml/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/Utils.scala 
b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
index da3edf5..a804f64 100644
--- a/src/main/scala/org/apache/sysml/api/ml/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
@@ -18,8 +18,69 @@
  */
 package org.apache.sysml.api.ml
 
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+
+/**
+ * Companion of [[Utils]]: holds the process-wide streams to restore after a redirect and
+ * the block formatter, which lives here so that Spark closures need no Utils instance.
+ */
+object Utils {
+  // A Scala object is initialized on first access, so these hold whatever System.out and
+  // System.err are at that moment. The Utils class touches them in its constructor, i.e.
+  // before any instance can have redirected the streams.
+  val originalOut = System.out
+  val originalErr = System.err
+
+  /**
+   * Formats all blocks of one partition into a single string (one line per block) and
+   * returns it as a one-element iterator.
+   */
+  private[ml] def describeBlocks(it: Iterator[(MatrixIndexes, MatrixBlock)]): Iterator[String] = {
+    val sb = new StringBuilder
+    for ((mi, mb) <- it) {
+      sb.append(mi.toString)
+      sb.append(" sparse? = ")
+      sb.append(mb.isInSparseFormat())
+      if (mb.isUltraSparse)
+        sb.append(" (ultra-sparse)")
+      sb.append(", nonzeros = ")
+      sb.append(mb.getNonZeros)
+      sb.append(", dimensions = ")
+      sb.append(mb.getNumRows)
+      sb.append(" X ")
+      sb.append(mb.getNumColumns)
+      sb.append("\n")
+    }
+    List[String](sb.toString).iterator
+  }
+}
 class Utils {
   def checkIfFileExists(filePath:String):Boolean = {
     return org.apache.sysml.runtime.util.MapReduceTool.existsFileOnHDFS(filePath)
   }
+
+  // --------------------------------------------------------------------------------
+  // Simple utility function to print the information about our binary blocked format
+  def getBinaryBlockInfo(binaryBlocks:JavaPairRDD[MatrixIndexes, MatrixBlock]):String = {
+    // The closure refers only to the companion object. Calling an instance method here would
+    // capture this Utils instance, which is not serializable and so cannot be shipped to executors.
+    val perPartition = binaryBlocks.rdd.mapPartitions(it => Utils.describeBlocks(it), true).collect
+    val sb = new StringBuilder
+    for ((str, partitionIndex) <- perPartition.zipWithIndex) {
+      sb.append("-------------------------------------\n")
+      sb.append("Partition " + partitionIndex  + ":\n")
+      sb.append(str)
+    }
+    sb.append("-------------------------------------\n")
+    sb.toString()
+  }
+  // Kept for existing callers; the formatting itself is done by the companion object.
+  def binaryBlockIteratorToString(it: Iterator[(MatrixIndexes, MatrixBlock)]): Iterator[String] =
+    Utils.describeBlocks(it)
+
+  val baos = new java.io.ByteArrayOutputStream()
+  val baes = new java.io.ByteArrayOutputStream()
+  // Reading the companion's vals here forces its initialization before this instance can
+  // redirect anything; otherwise the first access would happen in stopRedirectStdOut and
+  // the "original" streams would already be the redirected ones.
+  private val savedOut = Utils.originalOut
+  private val savedErr = Utils.originalErr
+
+  // Returns the buffered text and clears the buffer in one step. ByteArrayOutputStream
+  // synchronizes its methods on the instance, so holding the same monitor keeps a
+  // concurrent write from landing between toString and reset and being dropped.
+  private def drain(buffer: java.io.ByteArrayOutputStream): String = buffer.synchronized {
+    val captured = buffer.toString()
+    buffer.reset()
+    captured
+  }
+
+  // Redirects System.out / System.err of this JVM into the two in-memory buffers.
+  def startRedirectStdOut():Unit = {
+    System.setOut(new java.io.PrintStream(baos))
+    System.setErr(new java.io.PrintStream(baes))
+  }
+  // Returns what was captured since the last flush (stdout first, then stderr).
+  def flushStdOut():String = drain(baos) + drain(baes)
+  // Restores the saved streams first, then returns the text that is still buffered.
+  def stopRedirectStdOut():String = {
+    System.setOut(savedOut)
+    System.setErr(savedErr)
+    drain(baos) + drain(baes)
+  }
+  // --------------------------------------------------------------------------------
 }
\ No newline at end of file

Reply via email to