Repository: systemml
Updated Branches:
  refs/heads/master 4f29b3485 -> 323dd72a8


[SYSTEMML-1903] Fix robustness of codegen row ops w/ unknowns

This patch fixes special cases of codegen row templates with partial
unknowns, which is important for robustness during initial compilation
even though these unknowns lead to dynamic recompilation at runtime.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/323dd72a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/323dd72a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/323dd72a

Branch: refs/heads/master
Commit: 323dd72a8ed18687aa3019387c4ab7b0598bd9d5
Parents: 4f29b34
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Oct 19 15:07:54 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Oct 19 16:06:14 2017 -0700

----------------------------------------------------------------------
 .../hops/codegen/template/TemplateRow.java      | 38 ++++++++++----------
 .../hops/codegen/template/TemplateUtils.java    |  2 +-
 2 files changed, 21 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/323dd72a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 0389983..e664b9f 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -250,7 +250,7 @@ public class TemplateRow extends TemplateBase
                        else if (((AggUnaryOp)hop).getDirection() == 
Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
                                //vector add without temporary copy
                                if( cdata1 instanceof CNodeBinary && 
((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() )
-                                       out = new 
CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), 
+                                       out = new 
CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
                                                        
((CNodeBinary)cdata1).getType().getVectorAddPrimitive());
                                else    
                                        out = cdata1;
@@ -269,7 +269,7 @@ public class TemplateRow extends TemplateBase
                        {
                                //correct input under transpose
                                cdata1 = TemplateUtils.skipTranspose(cdata1, 
hop.getInput().get(0), tmp, compileLiterals);
-                               inHops.remove(hop.getInput().get(0)); 
+                               inHops.remove(hop.getInput().get(0));
                                
inHops.add(hop.getInput().get(0).getInput().get(0));
                                
                                //note: vectorMultAdd applicable to 
vector-scalar, and vector-vector
@@ -310,7 +310,8 @@ public class TemplateRow extends TemplateBase
                        CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
                        
                        // if one input is a matrix then we need to do vector 
by scalar operations
-                       if(hop.getInput().get(0).getDim1() > 1 && 
hop.getInput().get(0).getDim2() > 1 ) 
+                       if(hop.getInput().get(0).getDim1() > 1 && 
hop.getInput().get(0).getDim2() > 1 
+                               || (!hop.dimsKnown() && 
cdata1.getDataType()==DataType.MATRIX ) ) 
                        {
                                if( HopRewriteUtils.isUnary(hop, 
SUPPORTED_VECT_UNARY) ) {
                                        String opname = 
"VECT_"+((UnaryOp)hop).getOp().name();
@@ -320,12 +321,11 @@ public class TemplateRow extends TemplateBase
                                }
                                else 
                                        throw new RuntimeException("Unsupported 
unary matrix "
-                                                       + "operation: " + 
((UnaryOp)hop).getOp().name());
+                                               + "operation: " + 
((UnaryOp)hop).getOp().name());
                        }
                        else //general scalar case
                        {
                                cdata1 = 
TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
-                               
                                String primitiveOpName = 
((UnaryOp)hop).getOp().toString();
                                out = new CNodeUnary(cdata1, 
UnaryType.valueOf(primitiveOpName));
                        }
@@ -355,7 +355,9 @@ public class TemplateRow extends TemplateBase
                        
                        // if one input is a matrix then we need to do vector 
by scalar operations
                        if( (hop.getInput().get(0).getDim1() > 1 && 
hop.getInput().get(0).getDim2() > 1)
-                               || (hop.getInput().get(1).getDim1() > 1 && 
hop.getInput().get(1).getDim2() > 1))
+                               || (hop.getInput().get(1).getDim1() > 1 && 
hop.getInput().get(1).getDim2() > 1)
+                               || (!(hop.dimsKnown() && 
hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) 
+                                               && 
(cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix())))
                        {
                                if( HopRewriteUtils.isBinary(hop, 
SUPPORTED_VECT_BINARY) ) {
                                        if( TemplateUtils.isMatrix(cdata1) && 
(TemplateUtils.isMatrix(cdata2) 
@@ -371,14 +373,14 @@ public class TemplateRow extends TemplateBase
                                                        cdata2 = new 
CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                                                out = new CNodeBinary(cdata1, 
cdata2, BinType.valueOf(opname));
                                        }
-                                       if( cdata1 instanceof CNodeData && 
inHops2.isEmpty() 
+                                       if( cdata1 instanceof CNodeData && 
inHops2.isEmpty()
                                                && 
!(cdata1.getDataType()==DataType.SCALAR) ) {
                                                inHops2.put("X", 
hop.getInput().get(0));
                                        }
                                }
                                else 
                                        throw new RuntimeException("Unsupported 
binary matrix "
-                                                       + "operation: " + 
((BinaryOp)hop).getOp().name());
+                                               + "operation: " + 
((BinaryOp)hop).getOp().name());
                        }
                        else //one input is a vector/scalar other is a scalar
                        {
@@ -389,7 +391,7 @@ public class TemplateRow extends TemplateBase
                                        || 
(TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData
                                                && 
hop.getInput().get(1).getDataType().isMatrix()))
                                        cdata2 = new CNodeUnary(cdata2, 
UnaryType.LOOKUP_R);
-                               out = new CNodeBinary(cdata1, cdata2, 
BinType.valueOf(primitiveOpName));        
+                               out = new CNodeBinary(cdata1, cdata2, 
BinType.valueOf(primitiveOpName));
                        }
                }
                else if(hop instanceof TernaryOp) 
@@ -405,16 +407,16 @@ public class TemplateRow extends TemplateBase
                        
                        //construct ternary cnode, primitive operation derived 
from OpOp3
                        out = new CNodeTernary(cdata1, cdata2, cdata3, 
-                                       
TernaryType.valueOf(top.getOp().toString()));
+                               TernaryType.valueOf(top.getOp().toString()));
                }
-               else if( hop instanceof ParameterizedBuiltinOp ) 
+               else if( hop instanceof ParameterizedBuiltinOp )
                {
                        CNode cdata1 = 
tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
                        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, 
hop.getInput().get(0));
                        
                        CNode cdata2 = 
tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
                        CNode cdata3 = 
tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
-                       TernaryType ttype = (cdata2.isLiteral() && 
cdata2.getVarname().equals("Double.NaN")) ? 
+                       TernaryType ttype = (cdata2.isLiteral() && 
cdata2.getVarname().equals("Double.NaN")) ?
                                        TernaryType.REPLACE_NAN : 
TernaryType.REPLACE;
                        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
                }
@@ -422,7 +424,7 @@ public class TemplateRow extends TemplateBase
                {
                        CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
                        out = new CNodeTernary(cdata1, 
-                               TemplateUtils.createCNodeData(new 
LiteralOp(hop.getInput().get(0).getDim2()), true), 
+                               TemplateUtils.createCNodeData(new 
LiteralOp(hop.getInput().get(0).getDim2()), true),
                                
TemplateUtils.createCNodeData(hop.getInput().get(4), true),
                                (!hop.dimsKnown()||hop.getDim2()>1) ? 
TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
                }
@@ -456,13 +458,13 @@ public class TemplateRow extends TemplateBase
                
                @Override
                public int compare(Hop h1, Hop h2) {
-                       long ncells1 = h1.isScalar() ? Long.MIN_VALUE : 
-                               (h1==_X) ? Long.MAX_VALUE : (h1==_B1) ? 
Long.MAX_VALUE-1 : 
+                       long ncells1 = h1.isScalar() ? Long.MIN_VALUE :
+                               (h1==_X) ? Long.MAX_VALUE : (h1==_B1) ? 
Long.MAX_VALUE-1 :
                                h1.dimsKnown() ? h1.getLength() : 
Long.MAX_VALUE-2;
-                       long ncells2 = h2.isScalar() ? Long.MIN_VALUE : 
-                               (h2==_X) ? Long.MAX_VALUE : (h2==_B1) ? 
Long.MAX_VALUE-1 : 
+                       long ncells2 = h2.isScalar() ? Long.MIN_VALUE :
+                               (h2==_X) ? Long.MAX_VALUE : (h2==_B1) ? 
Long.MAX_VALUE-1 :
                                h2.dimsKnown() ? h2.getLength() : 
Long.MAX_VALUE-2;
-                       return (ncells1 > ncells2) ? -1 : (ncells1 < ncells2) ? 
1 : 0; 
+                       return (ncells1 > ncells2) ? -1 : (ncells1 < ncells2) ? 
1 : 0;
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/323dd72a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 497dae0..96e15cb 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -184,7 +184,7 @@ public class TemplateUtils
        public static RowType getRowType(Hop output, Hop... inputs) {
                Hop X = inputs[0];
                Hop B1 = (inputs.length>1) ? inputs[1] : null;
-               if( (X!=null && HopRewriteUtils.isEqualSize(output, X)) || 
X==null )
+               if( (X!=null && HopRewriteUtils.isEqualSize(output, X)) || 
X==null || !X.dimsKnown() )
                        return RowType.NO_AGG;
                else if( ((B1!=null && output.getDim1()==X.getDim1() && 
output.getDim2()==B1.getDim2())
                        || (output instanceof IndexingOp && 
HopRewriteUtils.isColumnRangeIndexing((IndexingOp)output)))

Reply via email to