Repository: systemml Updated Branches: refs/heads/master 5dee6c7ed -> 4916d454b
[SYSTEMML-2476] Extended literal replacement for scalar list lookups This patch extends the literal replacement rewrites during dynamic recompilation to also handle scalar list lookups (i.e., a scalar cast over a list lookup), which is important for hyperparameters that affect the sizes of subsequent operations. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/4916d454 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/4916d454 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/4916d454 Branch: refs/heads/master Commit: 4916d454ba370f76abacf3c6bdd6e2c526917fd9 Parents: 5dee6c7 Author: Matthias Boehm <[email protected]> Authored: Tue Jul 31 22:32:53 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Jul 31 22:33:28 2018 -0700 ---------------------------------------------------------------------- .../hops/recompile/LiteralReplacement.java | 35 ++++++++++++++++---- .../sysml/hops/rewrite/HopRewriteUtils.java | 32 ++++++++---------- .../org/apache/sysml/parser/Expression.java | 5 ++- .../paramserv/mnist_lenet_paramserv.dml | 2 +- 4 files changed, 47 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java index b2344c5..a8b3024 100644 --- a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java +++ b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java @@ -36,7 +36,6 @@ import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.lops.compile.Dag; import 
org.apache.sysml.parser.Expression.DataType; -import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; @@ -78,6 +77,7 @@ public class LiteralReplacement lit = (lit==null) ? replaceLiteralFullUnaryAggregateRightIndexing(c, vars) : lit; lit = (lit==null) ? replaceTReadMatrixFromList(c, vars) : lit; lit = (lit==null) ? replaceTReadMatrixLookupFromList(c, vars) : lit; + lit = (lit==null) ? replaceTReadScalarLookupFromList(c, vars) : lit; } //replace hop w/ literal on demand @@ -167,17 +167,17 @@ public class LiteralReplacement switch( cast.getOp() ) { case CAST_AS_INT: long ival = HopRewriteUtils.getIntValue(sdat); - ret = new LiteralOp(ival); + ret = new LiteralOp(ival); break; case CAST_AS_DOUBLE: double dval = HopRewriteUtils.getDoubleValue(sdat); - ret = new LiteralOp(dval); - break; + ret = new LiteralOp(dval); + break; case CAST_AS_BOOLEAN: boolean bval = HopRewriteUtils.getBooleanValue(sdat); - ret = new LiteralOp(bval); + ret = new LiteralOp(bval); break; - default: + default: //otherwise: do nothing } } @@ -368,7 +368,7 @@ public class LiteralReplacement ListObject list = (ListObject)vars.get(ixIn.getName()); String varname = Dag.getNextUniqueVarname(DataType.MATRIX); LiteralOp lit = (LiteralOp) ix.getInput().get(1); - MatrixObject mo = (MatrixObject) ((lit.getValueType() == ValueType.STRING) ? + MatrixObject mo = (MatrixObject) (!lit.getValueType().isNumeric() ? 
list.slice(lit.getName()) : list.slice((int)lit.getLongValue()-1)); vars.put(varname, mo); ret = HopRewriteUtils.createTransientRead(varname, c); @@ -377,6 +377,27 @@ public class LiteralReplacement return ret; } + private static LiteralOp replaceTReadScalarLookupFromList( Hop c, LocalVariableMap vars ) { + //pattern: as.scalar(X[i:i]) or as.scalar(X['a','a']) with X being a list + if( HopRewriteUtils.isUnary(c, OpOp1.CAST_AS_SCALAR) + && c.getInput().get(0) instanceof IndexingOp ) { + Hop ix = c.getInput().get(0); + Hop ixIn = c.getInput().get(0).getInput().get(0); + if( ixIn.getDataType() == DataType.LIST + && HopRewriteUtils.isData(ixIn, DataOpTypes.TRANSIENTREAD) + && ix.getInput().get(1) instanceof LiteralOp + && ix.getInput().get(2) instanceof LiteralOp + && ix.getInput().get(1) == ix.getInput().get(2) ) { + ListObject list = (ListObject)vars.get(ixIn.getName()); + LiteralOp lit = (LiteralOp) ix.getInput().get(1); + ScalarObject so = (ScalarObject) (!lit.getValueType().isNumeric() ? 
+ list.slice(lit.getName()) : list.slice((int)lit.getLongValue()-1)); + return ScalarObjectFactory.createLiteralOp(so); + } + } + return null; + } + /////////////////////////////// // Utility functions /////////////////////////////// http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 025f98a..4af5a10 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -91,26 +91,20 @@ public class HopRewriteUtils // literal handling public static boolean getBooleanValue( LiteralOp op ) { - switch( op.getValueType() ) - { + switch( op.getValueType() ) { case DOUBLE: return op.getDoubleValue() != 0; - case INT: return op.getLongValue() != 0; + case INT: return op.getLongValue() != 0; case BOOLEAN: return op.getBooleanValue(); - default: throw new HopsException("Invalid boolean value: "+op.getValueType()); } } - public static boolean getBooleanValueSafe( LiteralOp op ) - { - try - { - switch( op.getValueType() ) - { + public static boolean getBooleanValueSafe( LiteralOp op ) { + try { + switch( op.getValueType() ) { case DOUBLE: return op.getDoubleValue() != 0; - case INT: return op.getLongValue() != 0; + case INT: return op.getLongValue() != 0; case BOOLEAN: return op.getBooleanValue(); - default: throw new HopsException("Invalid boolean value: "+op.getValueType()); } } @@ -123,8 +117,9 @@ public class HopRewriteUtils public static double getDoubleValue( LiteralOp op ) { switch( op.getValueType() ) { + case STRING: case DOUBLE: return op.getDoubleValue(); - case INT: return op.getLongValue(); + case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 
1 : 0; default: throw new HopsException("Invalid double value: "+op.getValueType()); } @@ -133,7 +128,7 @@ public class HopRewriteUtils public static double getDoubleValueSafe( LiteralOp op ) { switch( op.getValueType() ) { case DOUBLE: return op.getDoubleValue(); - case INT: return op.getLongValue(); + case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 1 : 0; default: return Double.MAX_VALUE; } @@ -151,8 +146,9 @@ public class HopRewriteUtils */ public static long getIntValue( LiteralOp op ) { switch( op.getValueType() ) { - case DOUBLE: return UtilFunctions.toLong(op.getDoubleValue()); - case INT: return op.getLongValue(); + case DOUBLE: return UtilFunctions.toLong(op.getDoubleValue()); + case STRING: + case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 1 : 0; default: throw new HopsException("Invalid int value: "+op.getValueType()); } @@ -160,8 +156,8 @@ public class HopRewriteUtils public static long getIntValueSafe( LiteralOp op ) { switch( op.getValueType() ) { - case DOUBLE: return UtilFunctions.toLong(op.getDoubleValue()); - case INT: return op.getLongValue(); + case DOUBLE: return UtilFunctions.toLong(op.getDoubleValue()); + case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 1 : 0; default: return Long.MAX_VALUE; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 8299fbe..46e6442 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -205,7 +205,10 @@ public abstract class Expression implements ParseInfo * Value types (int, double, string, boolean, object, unknown). 
*/ public enum ValueType { - INT, DOUBLE, STRING, BOOLEAN, OBJECT, UNKNOWN + INT, DOUBLE, STRING, BOOLEAN, OBJECT, UNKNOWN; + public boolean isNumeric() { + return this == INT || this == DOUBLE; + } } /** http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml index bce4eea..35b0bd2 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml @@ -132,7 +132,7 @@ gradients = function(matrix[double] features, # PB: not be able to get scalar from list - C = 1 + C = as.scalar(hyperparams["C"]) Hin = 28 Win = 28 Hf = 5
