[SYSTEMML-1407] Extended code generator (right indexing in cell/rowagg)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2893e1ae
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2893e1ae
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2893e1ae

Branch: refs/heads/master
Commit: 2893e1aed03f9259fdd63504483484da92d761af
Parents: 9cbaf85
Author: Matthias Boehm <[email protected]>
Authored: Sun Mar 19 01:41:18 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Mar 19 01:41:18 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/IndexingOp.java  |  8 +++--
 .../java/org/apache/sysml/hops/LiteralOp.java   |  6 ++--
 .../sysml/hops/codegen/SpoofCompiler.java       | 34 +++++++++++++++++---
 .../sysml/hops/codegen/cplan/CNodeCell.java     |  2 +-
 .../sysml/hops/codegen/cplan/CNodeTernary.java  |  8 ++++-
 .../sysml/hops/codegen/template/CellTpl.java    | 13 +++++++-
 .../sysml/hops/codegen/template/RowAggTpl.java  | 34 ++++++++++++++------
 .../RewriteAlgebraicSimplificationStatic.java   |  4 +--
 .../rewrite/RewriteForLoopVectorization.java    | 16 ++++-----
 .../rewrite/RewriteIndexingVectorization.java   |  8 ++---
 .../functions/codegen/RowAggTmplTest.java       | 29 +++++++++++------
 .../scripts/functions/codegen/rowAggPattern6.R  | 32 ++++++++++++++++++
 .../functions/codegen/rowAggPattern6.dml        | 29 +++++++++++++++++
 13 files changed, 177 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/IndexingOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/IndexingOp.java b/src/main/java/org/apache/sysml/hops/IndexingOp.java
index 7edbbea..b77947f 100644
--- a/src/main/java/org/apache/sysml/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysml/hops/IndexingOp.java
@@ -76,11 +76,11 @@ public class IndexingOp extends Hop
        }
        
        
-       public boolean getRowLowerEqualsUpper(){
+       public boolean isRowLowerEqualsUpper(){
                return _rowLowerEqualsUpper;
        }
        
-       public boolean getColLowerEqualsUpper() {
+       public boolean isColLowerEqualsUpper() {
                return _colLowerEqualsUpper;
        }
        
@@ -397,6 +397,10 @@ public class IndexingOp extends Hop
                Hop input4 = getInput().get(3); //inpColL
                Hop input5 = getInput().get(4); //inpColU
                
+               //update single row/column flags (depends on CSE)
+               _rowLowerEqualsUpper = (input2 == input3);
+               _colLowerEqualsUpper = (input4 == input5);
+               
                //parse input information
                boolean allRows = 
                        (    input2 instanceof LiteralOp && 
HopRewriteUtils.getIntValueSafe((LiteralOp)input2)==1 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/LiteralOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/LiteralOp.java b/src/main/java/org/apache/sysml/hops/LiteralOp.java
index 835b7ab..e089177 100644
--- a/src/main/java/org/apache/sysml/hops/LiteralOp.java
+++ b/src/main/java/org/apache/sysml/hops/LiteralOp.java
@@ -183,7 +183,7 @@ public class LiteralOp extends Hop
                //do nothing; it is a scalar
        }
        
-       public long getLongValue() throws HopsException 
+       public long getLongValue() 
        {
                switch( getValueType() ) {
                        case INT:               
@@ -192,8 +192,10 @@ public class LiteralOp extends Hop
                                return UtilFunctions.toLong(value_double);
                        case STRING:
                                return Long.parseLong(value_string);    
+                       case BOOLEAN: 
+                               return value_boolean ? 1 : 0;
                        default:
-                               throw new HopsException("Can not coerce an 
object of type " + getValueType() + " into Long.");
+                               return -1;
                }
        }
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 6479917..f1dfb91 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -36,6 +36,8 @@ import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeCell;
 import org.apache.sysml.hops.codegen.cplan.CNodeData;
 import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
@@ -490,8 +492,15 @@ public class SpoofCompiler
                        
                        //remove spurious lookups on main input of cell template
                        if( tpl instanceof CNodeCell || tpl instanceof 
CNodeOuterProduct ) {
-                               CNode in1 = tpl.getInput().get(0);
-                               rFindAndRemoveLookup(tpl.getOutput(), 
in1.getVarname());
+                               CNodeData in1 = 
(CNodeData)tpl.getInput().get(0);
+                               rFindAndRemoveLookup(tpl.getOutput(), in1);
+                       }
+                       
+                       //remove invalid plans with column indexing on main 
input
+                       if( tpl instanceof CNodeCell ) {
+                               CNodeData in1 = 
(CNodeData)tpl.getInput().get(0);
+                               if( rHasLookupRC1(tpl.getOutput(), in1) )
+                                       cplans2.remove(e.getKey());
                        }
                        
                        //remove cplan w/ single op and w/o agg
@@ -517,17 +526,32 @@ public class SpoofCompiler
                        rCollectLeafIDs(c, leafs);
        }
        
-       private static void rFindAndRemoveLookup(CNode node, String nodeName) {
+       private static void rFindAndRemoveLookup(CNode node, CNodeData 
mainInput) {
                for( int i=0; i<node.getInput().size(); i++ ) {
                        CNode tmp = node.getInput().get(i);
                        if( tmp instanceof CNodeUnary && 
(((CNodeUnary)tmp).getType()==UnaryType.LOOKUP_R 
                                        || 
((CNodeUnary)tmp).getType()==UnaryType.LOOKUP_RC)
-                               && 
tmp.getInput().get(0).getVarname().equals(nodeName) )
+                               && tmp.getInput().get(0) instanceof CNodeData
+                               && 
((CNodeData)tmp.getInput().get(0)).getHopID()==mainInput.getHopID() )
                        {
                                node.getInput().set(i, tmp.getInput().get(0));
                        }
                        else
-                               rFindAndRemoveLookup(tmp, nodeName);
+                               rFindAndRemoveLookup(tmp, mainInput);
+               }
+       }
+       
+       private static boolean rHasLookupRC1(CNode node, CNodeData mainInput) {
+               boolean ret = false;
+               for( int i=0; i<node.getInput().size() && !ret; i++ ) {
+                       CNode tmp = node.getInput().get(i);
+                       if( tmp instanceof CNodeTernary && 
((CNodeTernary)tmp).getType()==TernaryType.LOOKUP_RC1 
+                               && tmp.getInput().get(0) instanceof CNodeData
+                               && 
((CNodeData)tmp.getInput().get(0)).getHopID() == mainInput.getHopID())
+                               ret = true;
+                       else
+                               ret |= rHasLookupRC1(tmp, mainInput);
                }
+               return ret;
        }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
index caf7b6a..527da28 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
@@ -173,7 +173,7 @@ public class CNodeCell extends CNodeTpl
                StringBuilder sb = new StringBuilder();
                sb.append("SPOOF CELLWISE [type=");
                sb.append(_type.name());
-               sb.append(", spafeSafe="+_sparseSafe);
+               sb.append(", sparseSafe="+_sparseSafe);
                sb.append(", castdtm="+_requiresCastdtm);
                sb.append(", mc="+_multipleConsumers);
                sb.append("]");

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index 9fdae3b..c9b389d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -27,7 +27,8 @@ import org.apache.sysml.parser.Expression.DataType;
 public class CNodeTernary extends CNode
 {
        public enum TernaryType {
-               PLUS_MULT, MINUS_MULT;
+               PLUS_MULT, MINUS_MULT,
+               LOOKUP_RC1;
                
                public static boolean contains(String value) {
                        for( TernaryType tt : values()  )
@@ -44,6 +45,9 @@ public class CNodeTernary extends CNode
                                case MINUS_MULT:
                                        return "    double %TMP% = %IN1% - 
%IN2% * %IN3%;\n;\n" ;
                                        
+                               case LOOKUP_RC1:
+                                       return "    double %TMP% = 
%IN1%[rowIndex*%IN2%+%IN3%-1];\n";   
+                                       
                                default: 
                                        throw new RuntimeException("Invalid 
ternary type: "+this.toString());
                        }
@@ -97,6 +101,7 @@ public class CNodeTernary extends CNode
                switch(_type) {
                        case PLUS_MULT: return "t(+*)";
                        case MINUS_MULT: return "t(-*)";
+                       case LOOKUP_RC1: return "u(ixrc1)";
                        default:
                                return super.toString();        
                }
@@ -107,6 +112,7 @@ public class CNodeTernary extends CNode
                switch(_type) {
                        case PLUS_MULT: 
                        case MINUS_MULT:
+                       case LOOKUP_RC1:
                                _rows = 0;
                                _cols = 0;
                                _dataType= DataType.SCALAR;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
index f5d11d1..c645eed 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
@@ -32,6 +32,8 @@ import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
@@ -56,7 +58,8 @@ public class CellTpl extends BaseTpl
 
        @Override
        public boolean open(Hop hop) {
-               return isValidOperation(hop);
+               return isValidOperation(hop)
+                       || (hop instanceof IndexingOp && 
((IndexingOp)hop).isColLowerEqualsUpper());
        }
 
        @Override
@@ -197,6 +200,14 @@ public class CellTpl extends BaseTpl
                        out = new CNodeTernary(cdata1, cdata2, cdata3, 
                                        
TernaryType.valueOf(top.getOp().toString()));
                }
+               else if( hop instanceof IndexingOp ) 
+               {
+                       CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
+                       out = new CNodeTernary(cdata1, 
+                                       TemplateUtils.createCNodeData(new 
LiteralOp(hop.getInput().get(0).getDim2()), true), 
+                                       
TemplateUtils.createCNodeData(hop.getInput().get(4), true),
+                                       TernaryType.LOOKUP_RC1);
+               }
                else if( HopRewriteUtils.isTransposeOperation(hop) ) 
                {
                        out = tmp.get(hop.getInput().get(0).getHopID());        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
index 9fa5efd..1aa380b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
@@ -28,11 +28,15 @@ import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
 import org.apache.sysml.hops.codegen.cplan.CNodeData;
 import org.apache.sysml.hops.codegen.cplan.CNodeRowAgg;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
@@ -90,18 +94,16 @@ public class RowAggTpl extends BaseTpl {
        public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable 
memo, boolean compileLiterals) {
                //recursively process required cplan output
                HashSet<Hop> inHops = new HashSet<Hop>();
+               HashMap<String, Hop> inHops2 = new HashMap<String,Hop>();
                HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();
                hop.resetVisitStatus();
-               rConstructCplan(hop, memo, tmp, inHops, compileLiterals);
+               rConstructCplan(hop, memo, tmp, inHops, inHops2, 
compileLiterals);
                hop.resetVisitStatus();
                
                //reorder inputs (ensure matrix is first input)
                LinkedList<Hop> sinHops = new LinkedList<Hop>(inHops);
-               for( Hop h : inHops )
-                       if( h.getDataType().isMatrix() && 
!TemplateUtils.isVector(h) ) {
-                               sinHops.remove(h);
-                               sinHops.addFirst(h);
-                       }
+               Hop X = inHops2.get("X");
+               sinHops.remove(X); sinHops.addFirst(X);
                
                //construct template node
                ArrayList<CNode> inputs = new ArrayList<CNode>();
@@ -114,14 +116,14 @@ public class RowAggTpl extends BaseTpl {
                return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), 
tpl);
        }
 
-       private void rConstructCplan(Hop hop, CPlanMemoTable memo, 
HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) 
+       private void rConstructCplan(Hop hop, CPlanMemoTable memo, 
HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, 
boolean compileLiterals) 
        {       
                //recursively process required childs
                MemoTableEntry me = memo.getBest(hop.getHopID(), 
TemplateType.RowAggTpl);
                for( int i=0; i<hop.getInput().size(); i++ ) {
                        Hop c = hop.getInput().get(i);
                        if( me.isPlanRef(i) )
-                               rConstructCplan(c, memo, tmp, inHops, 
compileLiterals);
+                               rConstructCplan(c, memo, tmp, inHops, inHops2, 
compileLiterals);
                        else {
                                CNodeData cdata = 
TemplateUtils.createCNodeData(c, compileLiterals);    
                                tmp.put(c.getHopID(), cdata);
@@ -137,8 +139,10 @@ public class RowAggTpl extends BaseTpl {
                        if(  ((AggUnaryOp)hop).getDirection() == Direction.Row 
&& ((AggUnaryOp)hop).getOp() == AggOp.SUM  ) {
                                if(hop.getInput().get(0).getDim2()==1)
                                        out = 
(cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new 
CNodeUnary(cdata1,UnaryType.LOOKUP_R);
-                               else
+                               else {
                                        out = new CNodeUnary(cdata1, 
UnaryType.ROW_SUMS);
+                                       inHops2.put("X", hop.getInput().get(0));
+                               }
                        }
                        else  if (((AggUnaryOp)hop).getDirection() == 
Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
                                //vector div add without temporary copy
@@ -167,8 +171,10 @@ public class RowAggTpl extends BaseTpl {
                                if(hop.getInput().get(0).getDim2()==1 && 
hop.getInput().get(1).getDim2()==1)
                                        out = new 
CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new 
CNodeUnary(cdata1, UnaryType.LOOKUP0),
                                                
(cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, 
UnaryType.LOOKUP0), BinType.MULT);
-                               else    
+                               else {
                                        out = new CNodeBinary(cdata1, cdata2, 
BinType.DOT_PRODUCT);
+                                       inHops2.put("X", hop.getInput().get(0));
+                               }
                        }
                }
                else if(hop instanceof BinaryOp)
@@ -194,6 +200,14 @@ public class RowAggTpl extends BaseTpl {
                                out = new CNodeBinary(cdata1, cdata2, 
BinType.valueOf(primitiveOpName));        
                        }
                }
+               else if( hop instanceof IndexingOp ) 
+               {
+                       CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
+                       out = new CNodeTernary(cdata1, 
+                                       TemplateUtils.createCNodeData(new 
LiteralOp(hop.getInput().get(0).getDim2()), true), 
+                                       
TemplateUtils.createCNodeData(hop.getInput().get(4), true),
+                                       TernaryType.LOOKUP_RC1);
+               }
                
                if( out.getDataType().isMatrix() ) {
                        out.setNumRows(hop.getDim1());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5af850f..3345ee1 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1234,8 +1234,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
        {
                //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] 
                if( hi instanceof IndexingOp 
-                       && ((IndexingOp)hi).getRowLowerEqualsUpper()
-                       && ((IndexingOp)hi).getColLowerEqualsUpper()  
+                       && ((IndexingOp)hi).isRowLowerEqualsUpper()
+                       && ((IndexingOp)hi).isColLowerEqualsUpper()  
                        && hi.getInput().get(0).getParent().size()==1 //rix is 
single mm consumer
                        && 
HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) )
                {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
index 273436e..e3e55fe 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
@@ -133,12 +133,12 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule
                                        && right.getInput().get(0) instanceof 
IndexingOp )
                                {
                                        IndexingOp ix = 
(IndexingOp)right.getInput().get(0);
-                                       if( ix.getRowLowerEqualsUpper() && 
ix.getInput().get(1) instanceof DataOp
+                                       if( ix.isRowLowerEqualsUpper() && 
ix.getInput().get(1) instanceof DataOp
                                                && 
ix.getInput().get(1).getName().equals(itervar) ){
                                                leftScalar = true;
                                                rowIx = true;
                                        }
-                                       else if( ix.getColLowerEqualsUpper() && 
ix.getInput().get(3) instanceof DataOp
+                                       else if( ix.isColLowerEqualsUpper() && 
ix.getInput().get(3) instanceof DataOp
                                                && 
ix.getInput().get(3).getName().equals(itervar) ){
                                                leftScalar = true;
                                                rowIx = false;
@@ -152,12 +152,12 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule
                                        && left.getInput().get(0) instanceof 
IndexingOp )
                                {
                                        IndexingOp ix = 
(IndexingOp)left.getInput().get(0);
-                                       if( ix.getRowLowerEqualsUpper() && 
ix.getInput().get(1) instanceof DataOp
+                                       if( ix.isRowLowerEqualsUpper() && 
ix.getInput().get(1) instanceof DataOp
                                                && 
ix.getInput().get(1).getName().equals(itervar) ){
                                                rightScalar = true;
                                                rowIx = true;
                                        }
-                                       else if( ix.getColLowerEqualsUpper() && 
ix.getInput().get(3) instanceof DataOp
+                                       else if( ix.isColLowerEqualsUpper() && 
ix.getInput().get(3) instanceof DataOp
                                                && 
ix.getInput().get(3).getName().equals(itervar) ){
                                                rightScalar = true;
                                                rowIx = false;
@@ -236,7 +236,7 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule
                                        IndexingOp rix1 = (IndexingOp) 
lixrhs.getInput().get(1);
                                        
                                        //check for rowwise
-                                       if(    lix.getRowLowerEqualsUpper() && 
rix0.getRowLowerEqualsUpper() && rix1.getRowLowerEqualsUpper() 
+                                       if(    lix.getRowLowerEqualsUpper() && 
rix0.isRowLowerEqualsUpper() && rix1.isRowLowerEqualsUpper() 
                                                && 
lix.getInput().get(2).getName().equals(itervar)
                                                && 
rix0.getInput().get(1).getName().equals(itervar)
                                                && 
rix1.getInput().get(1).getName().equals(itervar))
@@ -245,7 +245,7 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                                                rowIx = true;
                                        }
                                        //check for colwise
-                                       if(    lix.getColLowerEqualsUpper() && 
rix0.getColLowerEqualsUpper() && rix1.getColLowerEqualsUpper() 
+                                       if(    lix.getColLowerEqualsUpper() && 
rix0.isColLowerEqualsUpper() && rix1.isColLowerEqualsUpper() 
                                                && 
lix.getInput().get(4).getName().equals(itervar)
                                                && 
rix0.getInput().get(3).getName().equals(itervar)
                                                && 
rix1.getInput().get(3).getName().equals(itervar))
@@ -406,14 +406,14 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule
                boolean[] ret = new boolean[2]; //apply, rowIx
                
                //check for rowwise
-               if(    lix.getRowLowerEqualsUpper() && 
rix.getRowLowerEqualsUpper() 
+               if(    lix.getRowLowerEqualsUpper() && 
rix.isRowLowerEqualsUpper() 
                        && lix.getInput().get(2).getName().equals(itervar)
                        && rix.getInput().get(1).getName().equals(itervar) ) {
                        ret[0] = true;
                        ret[1] = true;
                }
                //check for colwise
-               if(    lix.getColLowerEqualsUpper() && 
rix.getColLowerEqualsUpper() 
+               if(    lix.getColLowerEqualsUpper() && 
rix.isColLowerEqualsUpper() 
                        && lix.getInput().get(4).getName().equals(itervar)
                        && rix.getInput().get(3).getName().equals(itervar) ) {
                        ret[0] = true;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
index f5af292..4ce1d43 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
@@ -107,8 +107,8 @@ public class RewriteIndexingVectorization extends HopRewriteRule
                if( hop instanceof IndexingOp ) //right indexing
                {
                        IndexingOp ihop0 = (IndexingOp) hop;
-                       boolean isSingleRow = ihop0.getRowLowerEqualsUpper();
-                       boolean isSingleCol = ihop0.getColLowerEqualsUpper();
+                       boolean isSingleRow = ihop0.isRowLowerEqualsUpper();
+                       boolean isSingleCol = ihop0.isColLowerEqualsUpper();
                        boolean appliedRow = false;
                        
                        //search for multiple indexing in same row
@@ -120,7 +120,7 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                ihops.add(ihop0);
                                for( Hop c : input.getParent() ){
                                        if( c != ihop0 && c instanceof 
IndexingOp && c.getInput().get(0) == input
-                                          && ((IndexingOp) 
c).getRowLowerEqualsUpper() 
+                                          && ((IndexingOp) 
c).isRowLowerEqualsUpper() 
                                           && 
c.getInput().get(1)==ihop0.getInput().get(1) )
                                        {
                                                ihops.add( c );
@@ -159,7 +159,7 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                ihops.add(ihop0);
                                for( Hop c : input.getParent() ){
                                        if( c != ihop0 && c instanceof 
IndexingOp && c.getInput().get(0) == input
-                                          && ((IndexingOp) 
c).getColLowerEqualsUpper() 
+                                          && ((IndexingOp) 
c).isColLowerEqualsUpper() 
                                           && 
c.getInput().get(3)==ihop0.getInput().get(3) )
                                        {
                                                ihops.add( c );

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 064c987..101bad8 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -35,11 +35,13 @@ import org.apache.sysml.test.utils.TestUtils;
 
 public class RowAggTmplTest extends AutomatedTestBase 
 {
-       private static final String TEST_NAME1 = "rowAggPattern1";
-       private static final String TEST_NAME2 = "rowAggPattern2";
-       private static final String TEST_NAME3 = "rowAggPattern3";
-       private static final String TEST_NAME4 = "rowAggPattern4";
-       private static final String TEST_NAME5 = "rowAggPattern5";
+       private static final String TEST_NAME = "rowAggPattern";
+       private static final String TEST_NAME1 = TEST_NAME+"1";
+       private static final String TEST_NAME2 = TEST_NAME+"2";
+       private static final String TEST_NAME3 = TEST_NAME+"3";
+       private static final String TEST_NAME4 = TEST_NAME+"4";
+       private static final String TEST_NAME5 = TEST_NAME+"5";
+       private static final String TEST_NAME6 = TEST_NAME+"6";
 
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
@@ -51,11 +53,8 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "0" }) );
-               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "1" }) );
-               addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "2" }) );
-               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "3" }) );
-               addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "4" }) );
+               for(int i=1; i<=6; i++)
+                       addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
        @Test   
@@ -83,6 +82,11 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME5, true, ExecType.CP );        
        }
        
+       @Test
+       public void testCodegenRowAggRewrite6() {
+               testCodegenIntegration( TEST_NAME6, true, ExecType.CP );        
+       }
+       
        @Test   
        public void testCodegenRowAgg1() {
                testCodegenIntegration( TEST_NAME1, false, ExecType.CP );
@@ -108,6 +112,11 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME5, false, ExecType.CP );       
        }
        
+       @Test
+       public void testCodegenRowAgg6() {
+               testCodegenIntegration( TEST_NAME6, false, ExecType.CP );       
+       }
+       
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/test/scripts/functions/codegen/rowAggPattern6.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern6.R 
b/src/test/scripts/functions/codegen/rowAggPattern6.R
new file mode 100644
index 0000000..af64a5f
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern6.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(1,15), 5, 3, byrow=TRUE);
+v = seq(1,3);
+P = cbind(seq(1,5),seq(2,6));
+
+S = t(X) %*% ((P[,1] * (1-P[,1])) * (X %*% v));
+
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); 
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/test/scripts/functions/codegen/rowAggPattern6.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern6.dml 
b/src/test/scripts/functions/codegen/rowAggPattern6.dml
new file mode 100644
index 0000000..e0521c8
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern6.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1,15), rows=5, cols=3);
+v = seq(1,3);
+P = cbind(seq(1,5),seq(2,6));
+
+S = t(X) %*% ((P[,1] * (1-P[,1])) * (X %*% v));
+
+write(S,$1)
+

Reply via email to