Repository: systemml
Updated Branches:
  refs/heads/master 827d73bd5 -> 430c04d59


[SYSTEMML-2475] Fix matrix/frame left indexing into list data types

This patch adds the missing support for left indexing matrices and
frames into lists. Furthermore, it includes a robustness fix for
inferring the output data type of builtin functions: when the target
is a list left-indexing expression, the target's data type was
incorrectly propagated to the builtin's output, so each builtin now
uses its fixed output data type (MATRIX or SCALAR) instead.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f2c0d13e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f2c0d13e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f2c0d13e

Branch: refs/heads/master
Commit: f2c0d13e2d785ae9143f38cc59b06f14b7dd4fc8
Parents: 827d73b
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Mon Jul 30 21:58:37 2018 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Mon Jul 30 21:58:37 2018 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/lops/PartialAggregate.java |  8 +++
 .../org/apache/sysml/parser/DMLTranslator.java  | 55 +++++++++-----------
 .../context/ExecutionContext.java               |  4 ++
 .../cp/ListIndexingCPInstruction.java           |  9 ++++
 4 files changed, 47 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/lops/PartialAggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
index 8358a1d..fa7b966 100644
--- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
@@ -377,6 +377,14 @@ public class PartialAggregate extends Lop
                                }
                        }
                        
+                       case SumProduct: {
+                               switch( dir ) {
+                                       case RowCol: return "ua+*";
+                                       case Row:    return "uar+*";
+                                       case Col:    return "uac+*";
+                               }
+                       }
+                       
                        case Max: {
                                if( dir == DirectionTypes.RowCol ) 
                                        return "uamax";

http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index bdfdf8f..e9b643e 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2325,10 +2325,7 @@ public class DMLTranslator
                }
                
                Hop currBuiltinOp = null;
-
-               if (target == null) {
-                       target = createTarget(source);
-               }
+               target = (target == null) ? createTarget(source) : target;
                
                // Construct the hop based on the type of Builtin function
                switch (source.getOpCode()) {
@@ -2344,15 +2341,15 @@ public class DMLTranslator
                case COLMEAN:
                case COLPROD:
                case COLVAR:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
                                
AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Col, expr);
                        break;
 
                case COLSD:
                        // colStdDevs = sqrt(colVariances)
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX,
                                        target.getValueType(), AggOp.VAR, 
Direction.Col, expr);
-                       currBuiltinOp = new UnaryOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new UnaryOp(target.getName(), 
DataType.MATRIX,
                                        target.getValueType(), Hop.OpOp1.SQRT, 
currBuiltinOp);
                        break;
 
@@ -2362,25 +2359,25 @@ public class DMLTranslator
                case ROWMEAN:
                case ROWPROD:
                case ROWVAR:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
                                
AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Row, expr);
                        break;
 
                case ROWINDEXMAX:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(), AggOp.MAXINDEX,
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(), AggOp.MAXINDEX,
                                        Direction.Row, expr);
                        break;
                
                case ROWINDEXMIN:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(), AggOp.MININDEX,
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(), AggOp.MININDEX,
                                        Direction.Row, expr);
                        break;
                
                case ROWSD:
                        // rowStdDevs = sqrt(rowVariances)
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX,
                                        target.getValueType(), AggOp.VAR, 
Direction.Row, expr);
-                       currBuiltinOp = new UnaryOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new UnaryOp(target.getName(), 
DataType.MATRIX,
                                        target.getValueType(), Hop.OpOp1.SQRT, 
currBuiltinOp);
                        break;
 
@@ -2409,38 +2406,38 @@ public class DMLTranslator
                        break;
                        
                case EXISTS:
-                       currBuiltinOp = new UnaryOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new UnaryOp(target.getName(), 
DataType.SCALAR,
                                target.getValueType(), Hop.OpOp1.EXISTS, expr);
                        break;
                
                case SUM:
                case PROD:
                case VAR:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(),
                                AggOp.valueOf(source.getOpCode().name()), 
Direction.RowCol, expr);
                        break;
 
                case MEAN:
                        if ( expr2 == null ) {
                                // example: x = mean(Y);
-                               currBuiltinOp = new 
AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), 
AggOp.MEAN,
+                               currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.MEAN,
                                        Direction.RowCol, expr);
                        }
                        else {
                                // example: x = mean(Y,W);
                                // stable weighted mean is implemented by using 
centralMoment with order = 0
                                Hop orderHop = new LiteralOp(0);
-                               currBuiltinOp=new TernaryOp(target.getName(), 
target.getDataType(), target.getValueType(), 
+                               currBuiltinOp=new TernaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(), 
                                                Hop.OpOp3.MOMENT, expr, expr2, 
orderHop);
                        }
                        break;
 
                case SD:
                        // stdDev = sqrt(variance)
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.SCALAR,
                                        target.getValueType(), AggOp.VAR, 
Direction.RowCol, expr);
                        
HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
-                       currBuiltinOp = new UnaryOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new UnaryOp(target.getName(), 
DataType.SCALAR,
                                        target.getValueType(), Hop.OpOp1.SQRT, 
currBuiltinOp);
                        break;
 
@@ -2448,7 +2445,7 @@ public class DMLTranslator
                case MAX:
                        //construct AggUnary for min(X) but BinaryOp for 
min(X,Y) and NaryOp for min(X,Y,Z)
                        currBuiltinOp = (expr2 == null) ? 
-                               new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                               new AggUnaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(),
                                        
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr) : 
                                (source.getAllExpr().length == 2) ?
                                new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
@@ -2480,14 +2477,14 @@ public class DMLTranslator
                        break;
                        
                case TRACE:
-                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(), AggOp.TRACE,
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(), AggOp.TRACE,
                                        Direction.RowCol, expr);
                        break;
 
                case TRANS:
                case DIAG:
                case REV:
-                       currBuiltinOp = new ReorgOp(target.getName(), 
target.getDataType(),
+                       currBuiltinOp = new ReorgOp(target.getName(), 
DataType.MATRIX,
                                target.getValueType(), 
ReOrgOp.valueOf(source.getOpCode().name()), expr);
                        break;
                        
@@ -2546,16 +2543,16 @@ public class DMLTranslator
                        }
                        break;
 
-               //data type casts       
+               //data type casts
                case CAST_AS_SCALAR:
                        currBuiltinOp = new UnaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr);
                        break;
                case CAST_AS_MATRIX:
                        currBuiltinOp = new UnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr);
-                       break;  
+                       break;
                case CAST_AS_FRAME:
                        currBuiltinOp = new UnaryOp(target.getName(), 
DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr);
-                       break;  
+                       break;
 
                //value type casts
                case CAST_AS_DOUBLE:
@@ -2725,7 +2722,7 @@ public class DMLTranslator
                        if( op == null )
                                throw new HopsException("Unsupported outer 
vector binary operation: "+((LiteralOp)expr3).getStringValue());
                        
-                       currBuiltinOp = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), op, expr, expr2);
+                       currBuiltinOp = new BinaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(), op, expr, expr2);
                        
((BinaryOp)currBuiltinOp).setOuterVectorOperation(true); //flag op as specific 
outer vector operation
                        currBuiltinOp.refreshSizeInformation(); //force size 
reevaluation according to 'outer' flag otherwise danger of incorrect dims
                        break;
@@ -2735,21 +2732,21 @@ public class DMLTranslator
                        ArrayList<Hop> inHops1 = new ArrayList<>();
                        inHops1.add(expr);
                        inHops1.add(expr2);
-                       currBuiltinOp = new DnnOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                       currBuiltinOp = new DnnOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
                                OpOpDnn.valueOf(source.getOpCode().name()), 
inHops1);
                        setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
                        break;
                }
                case AVG_POOL:
                case MAX_POOL: {
-                       currBuiltinOp = new DnnOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                       currBuiltinOp = new DnnOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
                                OpOpDnn.valueOf(source.getOpCode().name()), 
getALHopsForPoolingForwardIM2COL(expr, source, 1, hops));
                        setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
                        break;
                }
                case AVG_POOL_BACKWARD:
                case MAX_POOL_BACKWARD: {
-                       currBuiltinOp = new DnnOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                       currBuiltinOp = new DnnOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
                                OpOpDnn.valueOf(source.getOpCode().name()), 
getALHopsForConvOpPoolingCOL2IM(expr, source, 1, hops));
                        setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
                        break;
@@ -2757,7 +2754,7 @@ public class DMLTranslator
                case CONV2D:
                case CONV2D_BACKWARD_FILTER:
                case CONV2D_BACKWARD_DATA: {
-                       currBuiltinOp = new DnnOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                       currBuiltinOp = new DnnOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
                                OpOpDnn.valueOf(source.getOpCode().name()), 
getALHopsForConvOp(expr, source, 1, hops));
                        setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
                        break;

http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 2e0addf..88ec092 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -225,6 +225,10 @@ public class ExecutionContext {
                return (FrameObject) dat;
        }
 
+       public CacheableData<?> getCacheableData(CPOperand input) { //overload: resolve the operand's variable name and delegate to getCacheableData(String)
+               return getCacheableData(input.getName());
+       }
+       
        public CacheableData<?> getCacheableData(String varname) {
                Data dat = getVariable(varname);
                //error handling if non existing or no matrix

http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java
index f41601f..d22d9a8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java
@@ -23,6 +23,7 @@ import org.apache.sysml.lops.LeftIndex;
 import org.apache.sysml.lops.RightIndex;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 
 public final class ListIndexingCPInstruction extends IndexingCPInstruction {
@@ -76,6 +77,14 @@ public final class ListIndexingCPInstruction extends 
IndexingCPInstruction {
                                else
                                        ec.setVariable(output.getName(), 
lin.copy().set((int)rl.getLongValue()-1, scalar));
                        }
+                       else if( input2.getDataType().isMatrix() || input2.getDataType().isFrame() ) { //LIST <- MATRIX/FRAME (both are CacheableData)
+                               CacheableData<?> dat = ec.getCacheableData(input2);
+                               dat.enableCleanup(false); //presumably keeps the data alive while referenced by the list entry
+                               if( rl.getValueType()==ValueType.STRING )
+                                       ec.setVariable(output.getName(), lin.copy().set(rl.getStringValue(), dat));
+                               else
+                                       ec.setVariable(output.getName(), lin.copy().set((int)rl.getLongValue()-1, dat)); //1-based index -> 0-based
+                       }
                        else {
                                throw new DMLRuntimeException("Unsupported list 
"
                                        + "left indexing rhs type: 
"+input2.getDataType().name());

Reply via email to