Repository: systemml
Updated Branches:
  refs/heads/master 5dee6c7ed -> 4916d454b


[SYSTEMML-2476] Extended literal replacement for scalar list lookups 

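This patch extends the literal replacement rewrites during dynamic
recompilation to also handle scalar list lookups (i.e., a scalar cast
over a list lookup such as as.scalar(hyperparams["C"])), which is
important for hyperparameters that affect the sizes of subsequent
operations. It also adds a ValueType.isNumeric() helper, lets
HopRewriteUtils.getIntValue/getDoubleValue accept STRING literals, and
updates the paramserv LeNet test script to read C from the
hyperparameter list instead of hard-coding it.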


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/4916d454
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/4916d454
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/4916d454

Branch: refs/heads/master
Commit: 4916d454ba370f76abacf3c6bdd6e2c526917fd9
Parents: 5dee6c7
Author: Matthias Boehm <[email protected]>
Authored: Tue Jul 31 22:32:53 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jul 31 22:33:28 2018 -0700

----------------------------------------------------------------------
 .../hops/recompile/LiteralReplacement.java      | 35 ++++++++++++++++----
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 32 ++++++++----------
 .../org/apache/sysml/parser/Expression.java     |  5 ++-
 .../paramserv/mnist_lenet_paramserv.dml         |  2 +-
 4 files changed, 47 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
index b2344c5..a8b3024 100644
--- a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
+++ b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
@@ -36,7 +36,6 @@ import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.compile.Dag;
 import org.apache.sysml.parser.Expression.DataType;
-import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
@@ -78,6 +77,7 @@ public class LiteralReplacement
                                        lit = (lit==null) ? 
replaceLiteralFullUnaryAggregateRightIndexing(c, vars) : lit;
                                        lit = (lit==null) ? 
replaceTReadMatrixFromList(c, vars) : lit;
                                        lit = (lit==null) ? 
replaceTReadMatrixLookupFromList(c, vars) : lit;
+                                       lit = (lit==null) ? 
replaceTReadScalarLookupFromList(c, vars) : lit;
                                }
                                
                                //replace hop w/ literal on demand
@@ -167,17 +167,17 @@ public class LiteralReplacement
                                switch( cast.getOp() ) {
                                        case CAST_AS_INT:
                                                long ival = 
HopRewriteUtils.getIntValue(sdat);
-                                               ret = new LiteralOp(ival);      
        
+                                               ret = new LiteralOp(ival);
                                                break;
                                        case CAST_AS_DOUBLE:
                                                double dval = 
HopRewriteUtils.getDoubleValue(sdat);
-                                               ret = new LiteralOp(dval);      
        
-                                               break;                          
                
+                                               ret = new LiteralOp(dval);
+                                               break;
                                        case CAST_AS_BOOLEAN:
                                                boolean bval = 
HopRewriteUtils.getBooleanValue(sdat);
-                                               ret = new LiteralOp(bval);      
        
+                                               ret = new LiteralOp(bval);
                                                break;
-                                       default:        
+                                       default:
                                                //otherwise: do nothing
                                }
                        }
@@ -368,7 +368,7 @@ public class LiteralReplacement
                                ListObject list = 
(ListObject)vars.get(ixIn.getName());
                                String varname = 
Dag.getNextUniqueVarname(DataType.MATRIX);
                                LiteralOp lit = (LiteralOp) 
ix.getInput().get(1);
-                               MatrixObject mo = (MatrixObject) 
((lit.getValueType() == ValueType.STRING) ?
+                               MatrixObject mo = (MatrixObject) 
(!lit.getValueType().isNumeric() ?
                                        list.slice(lit.getName()) : 
list.slice((int)lit.getLongValue()-1));
                                vars.put(varname, mo);
                                ret = 
HopRewriteUtils.createTransientRead(varname, c);
@@ -377,6 +377,27 @@ public class LiteralReplacement
                return ret;
        }
        
+       private static LiteralOp replaceTReadScalarLookupFromList( Hop c, 
LocalVariableMap vars ) {
+               //pattern: as.scalar(X[i:i]) or as.scalar(X['a','a']) with X 
being a list
+               if( HopRewriteUtils.isUnary(c, OpOp1.CAST_AS_SCALAR)
+                       && c.getInput().get(0) instanceof IndexingOp ) {
+                       Hop ix = c.getInput().get(0);
+                       Hop ixIn = c.getInput().get(0).getInput().get(0);
+                       if( ixIn.getDataType() == DataType.LIST
+                               && HopRewriteUtils.isData(ixIn, 
DataOpTypes.TRANSIENTREAD)
+                               && ix.getInput().get(1) instanceof LiteralOp 
+                               && ix.getInput().get(2) instanceof LiteralOp
+                               && ix.getInput().get(1) == ix.getInput().get(2) 
) {
+                               ListObject list = 
(ListObject)vars.get(ixIn.getName());
+                               LiteralOp lit = (LiteralOp) 
ix.getInput().get(1);
+                               ScalarObject so = (ScalarObject) 
(!lit.getValueType().isNumeric() ?
+                                       list.slice(lit.getName()) : 
list.slice((int)lit.getLongValue()-1));
+                               return ScalarObjectFactory.createLiteralOp(so);
+                       }
+               }
+               return null;
+       }
+       
        ///////////////////////////////
        // Utility functions
        ///////////////////////////////

http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 025f98a..4af5a10 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -91,26 +91,20 @@ public class HopRewriteUtils
        // literal handling
 
        public static boolean getBooleanValue( LiteralOp op ) {
-               switch( op.getValueType() )
-               {
+               switch( op.getValueType() ) {
                        case DOUBLE:  return op.getDoubleValue() != 0; 
-                       case INT:         return op.getLongValue()   != 0;
+                       case INT:     return op.getLongValue()   != 0;
                        case BOOLEAN: return op.getBooleanValue();
-                       
                        default: throw new HopsException("Invalid boolean 
value: "+op.getValueType());
                }
        }
 
-       public static boolean getBooleanValueSafe( LiteralOp op )
-       {
-               try
-               {
-                       switch( op.getValueType() )
-                       {
+       public static boolean getBooleanValueSafe( LiteralOp op ) {
+               try {
+                       switch( op.getValueType() ) {
                                case DOUBLE:  return op.getDoubleValue() != 0; 
-                               case INT:         return op.getLongValue()   != 
0;
+                               case INT:     return op.getLongValue()   != 0;
                                case BOOLEAN: return op.getBooleanValue();
-                               
                                default: throw new HopsException("Invalid 
boolean value: "+op.getValueType());
                        }
                }
@@ -123,8 +117,9 @@ public class HopRewriteUtils
 
        public static double getDoubleValue( LiteralOp op ) {
                switch( op.getValueType() ) {
+                       case STRING:
                        case DOUBLE:  return op.getDoubleValue(); 
-                       case INT:         return op.getLongValue();
+                       case INT:     return op.getLongValue();
                        case BOOLEAN: return op.getBooleanValue() ? 1 : 0;
                        default: throw new HopsException("Invalid double value: 
"+op.getValueType());
                }
@@ -133,7 +128,7 @@ public class HopRewriteUtils
        public static double getDoubleValueSafe( LiteralOp op ) {
                switch( op.getValueType() ) {
                        case DOUBLE:  return op.getDoubleValue(); 
-                       case INT:         return op.getLongValue();
+                       case INT:     return op.getLongValue();
                        case BOOLEAN: return op.getBooleanValue() ? 1 : 0;
                        default: return Double.MAX_VALUE;
                }
@@ -151,8 +146,9 @@ public class HopRewriteUtils
         */
        public static long getIntValue( LiteralOp op ) {
                switch( op.getValueType() ) {
-                       case DOUBLE:  return 
UtilFunctions.toLong(op.getDoubleValue()); 
-                       case INT:         return op.getLongValue();
+                       case DOUBLE:  return 
UtilFunctions.toLong(op.getDoubleValue());
+                       case STRING:
+                       case INT:     return op.getLongValue();
                        case BOOLEAN: return op.getBooleanValue() ? 1 : 0;
                        default: throw new HopsException("Invalid int value: 
"+op.getValueType());
                }
@@ -160,8 +156,8 @@ public class HopRewriteUtils
        
        public static long getIntValueSafe( LiteralOp op ) {
                switch( op.getValueType() ) {
-                       case DOUBLE:  return 
UtilFunctions.toLong(op.getDoubleValue()); 
-                       case INT:         return op.getLongValue();
+                       case DOUBLE:  return 
UtilFunctions.toLong(op.getDoubleValue());
+                       case INT:     return op.getLongValue();
                        case BOOLEAN: return op.getBooleanValue() ? 1 : 0;
                        default: return Long.MAX_VALUE;
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 8299fbe..46e6442 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -205,7 +205,10 @@ public abstract class Expression implements ParseInfo
         * Value types (int, double, string, boolean, object, unknown).
         */
        public enum ValueType {
-               INT, DOUBLE, STRING, BOOLEAN, OBJECT, UNKNOWN
+               INT, DOUBLE, STRING, BOOLEAN, OBJECT, UNKNOWN;
+               public boolean isNumeric() {
+                       return this == INT || this == DOUBLE;
+               }
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/systemml/blob/4916d454/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
index bce4eea..35b0bd2 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -132,7 +132,7 @@ gradients = function(matrix[double] features,
 
 # PB: not be able to get scalar from list
 
-  C = 1
+  C = as.scalar(hyperparams["C"])
   Hin = 28
   Win = 28
   Hf = 5

Reply via email to