Repository: systemml Updated Branches: refs/heads/master f58717564 -> bc16b9e3d
[SYSTEMML-1662] Fix rewrite issues discovered by hop validator, part 2 This patch addresses additional issues discovered by the extended hop dag validator. Specifically this includes (1) proper handling of visit status in dag-splitting rewrites (by avoiding dangling references across dags), and (2) proper parent-child linking on rewriting persistent to transient reads and writes (mlcontext, jmlc). Furthermore, this includes minor fixes for the hop dag validator (regarding the expected number of inputs to function ops) and GPU-related compiler warnings. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/bc16b9e3 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/bc16b9e3 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/bc16b9e3 Branch: refs/heads/master Commit: bc16b9e3db290edbb62e9686998435f8eea66be3 Parents: f587175 Author: Matthias Boehm <[email protected]> Authored: Sat Jun 10 00:46:19 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 10 13:52:53 2017 -0700 ---------------------------------------------------------------------- .../sysml/api/mlcontext/ScriptExecutor.java | 30 ++++---------------- src/main/java/org/apache/sysml/hops/DataOp.java | 5 ++-- .../java/org/apache/sysml/hops/LiteralOp.java | 17 +++++++---- .../sysml/hops/rewrite/HopDagValidator.java | 6 ++-- .../RewriteSplitDagDataDependentOperators.java | 8 +++++- .../rewrite/RewriteSplitDagUnknownCSVRead.java | 25 ++++++++++++---- .../controlprogram/ParForProgramBlock.java | 2 -- .../context/ExecutionContext.java | 1 - .../org/apache/sysml/test/gpu/GPUTests.java | 1 + 9 files changed, 50 insertions(+), 45 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index cd884f2..0035350 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -153,9 +153,7 @@ public class ScriptExecutor { protected void constructHops() { try { dmlTranslator.constructHops(dmlProgram); - } catch (LanguageException e) { - throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e); - } catch (ParseException e) { + } catch (LanguageException | ParseException e) { throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e); } } @@ -168,11 +166,7 @@ public class ScriptExecutor { protected void rewriteHops() { try { dmlTranslator.rewriteHopsDAG(dmlProgram); - } catch (LanguageException e) { - throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e); - } catch (HopsException e) { - throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e); - } catch (ParseException e) { + } catch (LanguageException | HopsException | ParseException e) { throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e); } } @@ -199,13 +193,7 @@ public class ScriptExecutor { protected void constructLops() { try { dmlTranslator.constructLops(dmlProgram); - } catch (ParseException e) { - throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); - } catch (LanguageException e) { - throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); - } catch (HopsException e) { - throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); - } catch (LopsException e) { + } catch (ParseException | LanguageException | HopsException | LopsException e) { throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); } } @@ -218,13 +206,7 @@ public class ScriptExecutor { protected void generateRuntimeProgram() { try { runtimeProgram = dmlProgram.getRuntimeProgram(config); - } catch (LanguageException e) { - throw new MLContextException("Exception occurred while generating runtime program", e); - } catch (DMLRuntimeException e) { - throw new MLContextException("Exception occurred while generating runtime program", e); - } catch (LopsException e) { - throw new MLContextException("Exception occurred while generating runtime program", e); - } catch (IOException e) { + } catch (LanguageException | DMLRuntimeException | LopsException | IOException e) { throw new MLContextException("Exception occurred while generating runtime program", e); } } @@ -480,9 +462,7 @@ public class ScriptExecutor { ProgramRewriter programRewriter = new ProgramRewriter(rewrite); try { programRewriter.rewriteProgramHopDAGs(dmlProgram); - } catch (LanguageException e) { - throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e); - } catch (HopsException e) { + } catch (LanguageException | HopsException e) { throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/DataOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DataOp.java b/src/main/java/org/apache/sysml/hops/DataOp.java index bba0898..8653360 100644 --- a/src/main/java/org/apache/sysml/hops/DataOp.java +++ b/src/main/java/org/apache/sysml/hops/DataOp.java @@ -592,11 +592,10 @@ public class DataOp extends Hop * @param inputName The name of the input to remove */ public void removeInput(String inputName) { - int inputIndex = getParameterIndex(inputName); - _input.remove(inputIndex); + Hop tmp = _input.remove(inputIndex); + tmp._parent.remove(this); _paramIndexMap.remove(inputName); - for (Entry<String, Integer> entry : _paramIndexMap.entrySet()) { if (entry.getValue() > inputIndex) { _paramIndexMap.put(entry.getKey(), (entry.getValue() - 1)); http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/LiteralOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/LiteralOp.java b/src/main/java/org/apache/sysml/hops/LiteralOp.java index 503d270..b96d032 100644 --- a/src/main/java/org/apache/sysml/hops/LiteralOp.java +++ b/src/main/java/org/apache/sysml/hops/LiteralOp.java @@ -30,7 +30,6 @@ import org.apache.sysml.runtime.util.UtilFunctions; public class LiteralOp extends Hop { - private double value_double = Double.NaN; private long value_long = Long.MAX_VALUE; private String value_string; @@ -44,22 +43,30 @@ public class LiteralOp extends Hop public LiteralOp(double value) { super(String.valueOf(value), DataType.SCALAR, ValueType.DOUBLE); - this.value_double = value; + value_double = value; } public LiteralOp(long value) { super(String.valueOf(value), DataType.SCALAR, ValueType.INT); - this.value_long = value; + value_long = value; } public LiteralOp(String value) { super(value, DataType.SCALAR, ValueType.STRING); - this.value_string = value; + value_string = value; } public LiteralOp(boolean value) { super(String.valueOf(value), DataType.SCALAR, ValueType.BOOLEAN); - this.value_boolean = value; + value_boolean = value; + } + + public LiteralOp(LiteralOp that) { + super(that.getName(), that.getDataType(), that.getValueType()); + value_double = that.value_double; + value_long = that.value_long; + value_string = that.value_string; + value_boolean = that.value_boolean; } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java index 52fb36f..8cb5e1e 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.DataOp; +import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.LiteralOp; @@ -85,6 +86,7 @@ public class HopDagValidator { private static void rValidateHop(final Hop hop, final ValidatorState state) throws HopsException { final long id = hop.getHopID(); + //check visit status final boolean seen = !state.seen.add(id); check(seen == hop.isVisited(), hop, "seen previously is %b but does not match hop visit status", seen); @@ -107,8 +109,8 @@ public class HopDagValidator { //check empty children (other variable-length Hops must have at least one child) if( input.isEmpty() ) - check(hop instanceof DataOp || hop instanceof LiteralOp, hop, - "is not a dataop/literal but has no children"); + check(hop instanceof DataOp || hop instanceof FunctionOp || hop instanceof LiteralOp, hop, + "is not a dataop/functionop/literal but has no children"); // check Hop has a legal arity (number of children) hop.checkArity(); http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index 85d34dc..6c2fda9 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -20,8 +20,11 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.DataOp; @@ -63,7 +66,6 @@ import org.apache.sysml.runtime.matrix.data.Pair; */ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule { - private static String _varnamePredix = "_sbcvar"; private static IDSequence _seq = new IDSequence(); @@ -71,6 +73,10 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException { + //DAG splits not required for forced single node + if( DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE ) + return new ArrayList<StatementBlock>(Arrays.asList(sb)); + ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); //collect all unknown csv reads hops http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java index 8396813..2e9847a 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java @@ -20,12 +20,16 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; +import java.util.Arrays; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.FileFormatTypes; import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.VariableSet; @@ -45,6 +49,10 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException { + //DAG splits not required for forced single node + if( DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE ) + return new ArrayList<StatementBlock>(Arrays.asList(sb)); + ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); //collect all unknown csv reads hops @@ -66,16 +74,22 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule //move csv reads incl reblock to new statement block //(and replace original persistent read with transient read) ArrayList<Hop> sb1hops = new ArrayList<Hop>(); - for( Hop c : cand ) + for( Hop reblock : cand ) { - Hop reblock = c; long rlen = reblock.getDim1(); long clen = reblock.getDim2(); long nnz = reblock.getNnz(); - UpdateType update = c.getUpdateType(); + UpdateType update = reblock.getUpdateType(); long brlen = reblock.getRowsInBlock(); long bclen = reblock.getColsInBlock(); - + + //replace reblock inputs to avoid dangling references across dags + //(otherwise, for instance, literal ops are shared across dags) + for( int i=0; i<reblock.getInput().size(); i++ ) + if( reblock.getInput().get(i) instanceof LiteralOp ) + HopRewriteUtils.replaceChildReference(reblock, reblock.getInput().get(i), + new LiteralOp((LiteralOp)reblock.getInput().get(i))); + //create new transient read DataOp tread = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(), DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen); @@ -83,8 +97,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule //replace reblock with transient read ArrayList<Hop> parents = new ArrayList<Hop>(reblock.getParent()); - for( int i=0; i<parents.size(); i++ ) - { + for( int i=0; i<parents.size(); i++ ) { Hop parent = parents.get(i); HopRewriteUtils.replaceChildReference(parent, reblock, tread); } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java index 95e28e7..4387362 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java @@ -102,8 +102,6 @@ import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.StringObject; import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; -import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; -import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.io.IOUtilFunctions; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.OutputInfo; http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index fb179f5..bc603ba 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -20,7 +20,6 @@ package org.apache.sysml.runtime.controlprogram.context; import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.List; http://git-wip-us.apache.org/repos/asf/systemml/blob/bc16b9e3/src/test/java/org/apache/sysml/test/gpu/GPUTests.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java index ba61dc0..195968a 100644 --- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java +++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java @@ -176,6 +176,7 @@ public abstract class GPUTests extends AutomatedTestBase { "Relative error(%f) is more than threshold (%f). Expected = %f, Actual = %f, differed at [%d, %d]", relativeError, getTHRESHOLD(), expectedDouble, actualDouble, i, j); Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD()); + format.close(); } else { Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD()); }
