Repository: systemml
Updated Branches:
  refs/heads/master f58717564 -> bc16b9e3d


[SYSTEMML-1662] Fix rewrite issues discovered by hop validator, part 2

This patch addresses additional issues discovered by the extended hop
dag validator. Specifically, this includes (1) proper handling of visit
status in dag-splitting rewrites (by avoiding dangling references across
dags, e.g., literal ops that were previously shared between the split
dags), and (2) proper parent-child linking when rewriting persistent
reads and writes into transient reads and writes (mlcontext, jmlc).

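Furthermore, both dag-splitting rewrites are now skipped entirely under
forced single-node execution, where such splits are not required. This
patch also includes minor fixes for the hop dag validator (function ops
may legitimately have no inputs), fixes for GPU-related compiler
warnings, and a consolidation of redundant catch clauses in
ScriptExecutor into multi-catch clauses.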


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/bc16b9e3
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/bc16b9e3
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/bc16b9e3

Branch: refs/heads/master
Commit: bc16b9e3db290edbb62e9686998435f8eea66be3
Parents: f587175
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 10 00:46:19 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 10 13:52:53 2017 -0700

----------------------------------------------------------------------
 .../sysml/api/mlcontext/ScriptExecutor.java     | 30 ++++----------------
 src/main/java/org/apache/sysml/hops/DataOp.java |  5 ++--
 .../java/org/apache/sysml/hops/LiteralOp.java   | 17 +++++++----
 .../sysml/hops/rewrite/HopDagValidator.java     |  6 ++--
 .../RewriteSplitDagDataDependentOperators.java  |  8 +++++-
 .../rewrite/RewriteSplitDagUnknownCSVRead.java  | 25 ++++++++++++----
 .../controlprogram/ParForProgramBlock.java      |  2 --
 .../context/ExecutionContext.java               |  1 -
 .../org/apache/sysml/test/gpu/GPUTests.java     |  1 +
 9 files changed, 50 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java 
b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index cd884f2..0035350 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -153,9 +153,7 @@ public class ScriptExecutor {
        protected void constructHops() {
                try {
                        dmlTranslator.constructHops(dmlProgram);
-               } catch (LanguageException e) {
-                       throw new MLContextException("Exception occurred while 
constructing HOPS (high-level operators)", e);
-               } catch (ParseException e) {
+               } catch (LanguageException | ParseException e) {
                        throw new MLContextException("Exception occurred while 
constructing HOPS (high-level operators)", e);
                }
        }
@@ -168,11 +166,7 @@ public class ScriptExecutor {
        protected void rewriteHops() {
                try {
                        dmlTranslator.rewriteHopsDAG(dmlProgram);
-               } catch (LanguageException e) {
-                       throw new MLContextException("Exception occurred while 
rewriting HOPS (high-level operators)", e);
-               } catch (HopsException e) {
-                       throw new MLContextException("Exception occurred while 
rewriting HOPS (high-level operators)", e);
-               } catch (ParseException e) {
+               } catch (LanguageException | HopsException | ParseException e) {
                        throw new MLContextException("Exception occurred while 
rewriting HOPS (high-level operators)", e);
                }
        }
@@ -199,13 +193,7 @@ public class ScriptExecutor {
        protected void constructLops() {
                try {
                        dmlTranslator.constructLops(dmlProgram);
-               } catch (ParseException e) {
-                       throw new MLContextException("Exception occurred while 
constructing LOPS (low-level operators)", e);
-               } catch (LanguageException e) {
-                       throw new MLContextException("Exception occurred while 
constructing LOPS (low-level operators)", e);
-               } catch (HopsException e) {
-                       throw new MLContextException("Exception occurred while 
constructing LOPS (low-level operators)", e);
-               } catch (LopsException e) {
+               } catch (ParseException | LanguageException | HopsException | 
LopsException e) {
                        throw new MLContextException("Exception occurred while 
constructing LOPS (low-level operators)", e);
                }
        }
@@ -218,13 +206,7 @@ public class ScriptExecutor {
        protected void generateRuntimeProgram() {
                try {
                        runtimeProgram = dmlProgram.getRuntimeProgram(config);
-               } catch (LanguageException e) {
-                       throw new MLContextException("Exception occurred while 
generating runtime program", e);
-               } catch (DMLRuntimeException e) {
-                       throw new MLContextException("Exception occurred while 
generating runtime program", e);
-               } catch (LopsException e) {
-                       throw new MLContextException("Exception occurred while 
generating runtime program", e);
-               } catch (IOException e) {
+               } catch (LanguageException | DMLRuntimeException | 
LopsException | IOException e) {
                        throw new MLContextException("Exception occurred while 
generating runtime program", e);
                }
        }
@@ -480,9 +462,7 @@ public class ScriptExecutor {
                        ProgramRewriter programRewriter = new 
ProgramRewriter(rewrite);
                        try {
                                
programRewriter.rewriteProgramHopDAGs(dmlProgram);
-                       } catch (LanguageException e) {
-                               throw new MLContextException("Exception 
occurred while rewriting persistent reads and writes", e);
-                       } catch (HopsException e) {
+                       } catch (LanguageException | HopsException e) {
                                throw new MLContextException("Exception 
occurred while rewriting persistent reads and writes", e);
                        }
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/DataOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DataOp.java 
b/src/main/java/org/apache/sysml/hops/DataOp.java
index bba0898..8653360 100644
--- a/src/main/java/org/apache/sysml/hops/DataOp.java
+++ b/src/main/java/org/apache/sysml/hops/DataOp.java
@@ -592,11 +592,10 @@ public class DataOp extends Hop
         * @param inputName The name of the input to remove
         */
        public void removeInput(String inputName) {
-
                int inputIndex = getParameterIndex(inputName);
-               _input.remove(inputIndex);
+               Hop tmp = _input.remove(inputIndex);
+               tmp._parent.remove(this);
                _paramIndexMap.remove(inputName);
-
                for (Entry<String, Integer> entry : _paramIndexMap.entrySet()) {
                        if (entry.getValue() > inputIndex) {
                                _paramIndexMap.put(entry.getKey(), 
(entry.getValue() - 1));

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/LiteralOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/LiteralOp.java 
b/src/main/java/org/apache/sysml/hops/LiteralOp.java
index 503d270..b96d032 100644
--- a/src/main/java/org/apache/sysml/hops/LiteralOp.java
+++ b/src/main/java/org/apache/sysml/hops/LiteralOp.java
@@ -30,7 +30,6 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class LiteralOp extends Hop 
 {
-       
        private double value_double = Double.NaN;
        private long value_long = Long.MAX_VALUE;
        private String value_string;
@@ -44,22 +43,30 @@ public class LiteralOp extends Hop
        
        public LiteralOp(double value) {
                super(String.valueOf(value), DataType.SCALAR, ValueType.DOUBLE);
-               this.value_double = value;
+               value_double = value;
        }
 
        public LiteralOp(long value) {
                super(String.valueOf(value), DataType.SCALAR, ValueType.INT);
-               this.value_long = value;
+               value_long = value;
        }
 
        public LiteralOp(String value) {
                super(value, DataType.SCALAR, ValueType.STRING);
-               this.value_string = value;
+               value_string = value;
        }
 
        public LiteralOp(boolean value) {
                super(String.valueOf(value), DataType.SCALAR, 
ValueType.BOOLEAN);
-               this.value_boolean = value;
+               value_boolean = value;
+       }
+       
+       public LiteralOp(LiteralOp that) {
+               super(that.getName(), that.getDataType(), that.getValueType());
+               value_double = that.value_double;
+               value_long = that.value_long;
+               value_string = that.value_string;
+               value_boolean = that.value_boolean;
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
index 52fb36f..8cb5e1e 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
@@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 import org.apache.sysml.hops.DataOp;
+import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.LiteralOp;
@@ -85,6 +86,7 @@ public class HopDagValidator {
        private static void rValidateHop(final Hop hop, final ValidatorState 
state) throws HopsException {
                final long id = hop.getHopID();
 
+               //check visit status
                final boolean seen = !state.seen.add(id);
                check(seen == hop.isVisited(), hop,
                                "seen previously is %b but does not match hop 
visit status", seen);
@@ -107,8 +109,8 @@ public class HopDagValidator {
 
                //check empty children (other variable-length Hops must have at 
least one child)
                if( input.isEmpty() )
-                       check(hop instanceof DataOp || hop instanceof 
LiteralOp, hop,
-                                       "is not a dataop/literal but has no 
children");
+                       check(hop instanceof DataOp || hop instanceof 
FunctionOp || hop instanceof LiteralOp, hop,
+                                       "is not a dataop/functionop/literal but 
has no children");
 
                // check Hop has a legal arity (number of children)
                hop.checkArity();

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index 85d34dc..6c2fda9 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -20,8 +20,11 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashSet;
 
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.DataOp;
@@ -63,7 +66,6 @@ import org.apache.sysml.runtime.matrix.data.Pair;
  */
 public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewriteRule
 {
-
        private static String _varnamePredix = "_sbcvar";
        private static IDSequence _seq = new IDSequence();
        
@@ -71,6 +73,10 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
        public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock 
sb, ProgramRewriteStatus state)
                throws HopsException 
        {
+               //DAG splits not required for forced single node
+               if( DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE )
+                       return new ArrayList<StatementBlock>(Arrays.asList(sb));
+               
                ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
        
                //collect all unknown csv reads hops

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
index 8396813..2e9847a 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
@@ -20,12 +20,16 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.Hop.FileFormatTypes;
 import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.parser.DataIdentifier;
 import org.apache.sysml.parser.StatementBlock;
 import org.apache.sysml.parser.VariableSet;
@@ -45,6 +49,10 @@ public class RewriteSplitDagUnknownCSVRead extends 
StatementBlockRewriteRule
        public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock 
sb, ProgramRewriteStatus state)
                throws HopsException 
        {
+               //DAG splits not required for forced single node
+               if( DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE )
+                       return new ArrayList<StatementBlock>(Arrays.asList(sb));
+               
                ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
                
                //collect all unknown csv reads hops
@@ -66,16 +74,22 @@ public class RewriteSplitDagUnknownCSVRead extends 
StatementBlockRewriteRule
                                //move csv reads incl reblock to new statement 
block
                                //(and replace original persistent read with 
transient read)
                                ArrayList<Hop> sb1hops = new ArrayList<Hop>();  
                
-                               for( Hop c : cand )
+                               for( Hop reblock : cand )
                                {
-                                       Hop reblock = c;
                                        long rlen = reblock.getDim1();
                                        long clen = reblock.getDim2();
                                        long nnz = reblock.getNnz();
-                                       UpdateType update = c.getUpdateType();
+                                       UpdateType update = 
reblock.getUpdateType();
                                        long brlen = reblock.getRowsInBlock();
                                        long bclen = reblock.getColsInBlock();
-       
+                                       
+                                       //replace reblock inputs to avoid 
dangling references across dags
+                                       //(otherwise, for instance, literal ops 
are shared across dags)
+                                       for( int i=0; 
i<reblock.getInput().size(); i++ )
+                                               if( reblock.getInput().get(i) 
instanceof LiteralOp )
+                                                       
HopRewriteUtils.replaceChildReference(reblock, reblock.getInput().get(i), 
+                                                               new 
LiteralOp((LiteralOp)reblock.getInput().get(i)));
+                                       
                                        //create new transient read
                                        DataOp tread = new 
DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(),
                                    DataOpTypes.TRANSIENTREAD, null, rlen, 
clen, nnz, update, brlen, bclen);
@@ -83,8 +97,7 @@ public class RewriteSplitDagUnknownCSVRead extends 
StatementBlockRewriteRule
                                        
                                        //replace reblock with transient read
                                        ArrayList<Hop> parents = new 
ArrayList<Hop>(reblock.getParent());
-                                       for( int i=0; i<parents.size(); i++ )
-                                       {
+                                       for( int i=0; i<parents.size(); i++ ) {
                                                Hop parent = parents.get(i);
                                                
HopRewriteUtils.replaceChildReference(parent, reblock, tread);
                                        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index 95e28e7..4387362 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -102,8 +102,6 @@ import 
org.apache.sysml.runtime.instructions.cp.DoubleObject;
 import org.apache.sysml.runtime.instructions.cp.IntObject;
 import org.apache.sysml.runtime.instructions.cp.StringObject;
 import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 import org.apache.sysml.runtime.io.IOUtilFunctions;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index fb179f5..bc603ba 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -20,7 +20,6 @@
 package org.apache.sysml.runtime.controlprogram.context;
 
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java 
b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index ba61dc0..195968a 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -176,6 +176,7 @@ public abstract class GPUTests extends AutomatedTestBase {
                                                                "Relative 
error(%f) is more than threshold (%f). Expected = %f, Actual = %f, differed at 
[%d, %d]",
                                                                relativeError, 
getTHRESHOLD(), expectedDouble, actualDouble, i, j);
                                                
Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD());
+                                               format.close();
                                        } else {
                                                
Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD());
                                        }

Reply via email to