[SYSTEMML-1407] Extended code generator (right indexing in cell/rowagg) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2893e1ae Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2893e1ae Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2893e1ae
Branch: refs/heads/master Commit: 2893e1aed03f9259fdd63504483484da92d761af Parents: 9cbaf85 Author: Matthias Boehm <[email protected]> Authored: Sun Mar 19 01:41:18 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Mar 19 01:41:18 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/IndexingOp.java | 8 +++-- .../java/org/apache/sysml/hops/LiteralOp.java | 6 ++-- .../sysml/hops/codegen/SpoofCompiler.java | 34 +++++++++++++++++--- .../sysml/hops/codegen/cplan/CNodeCell.java | 2 +- .../sysml/hops/codegen/cplan/CNodeTernary.java | 8 ++++- .../sysml/hops/codegen/template/CellTpl.java | 13 +++++++- .../sysml/hops/codegen/template/RowAggTpl.java | 34 ++++++++++++++------ .../RewriteAlgebraicSimplificationStatic.java | 4 +-- .../rewrite/RewriteForLoopVectorization.java | 16 ++++----- .../rewrite/RewriteIndexingVectorization.java | 8 ++--- .../functions/codegen/RowAggTmplTest.java | 29 +++++++++++------ .../scripts/functions/codegen/rowAggPattern6.R | 32 ++++++++++++++++++ .../functions/codegen/rowAggPattern6.dml | 29 +++++++++++++++++ 13 files changed, 177 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/IndexingOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/IndexingOp.java b/src/main/java/org/apache/sysml/hops/IndexingOp.java index 7edbbea..b77947f 100644 --- a/src/main/java/org/apache/sysml/hops/IndexingOp.java +++ b/src/main/java/org/apache/sysml/hops/IndexingOp.java @@ -76,11 +76,11 @@ public class IndexingOp extends Hop } - public boolean getRowLowerEqualsUpper(){ + public boolean isRowLowerEqualsUpper(){ return _rowLowerEqualsUpper; } - public boolean getColLowerEqualsUpper() { + public boolean isColLowerEqualsUpper() { return _colLowerEqualsUpper; } @@ -397,6 +397,10 @@ public class IndexingOp extends Hop Hop input4 = getInput().get(3); //inpColL Hop input5 = getInput().get(4); //inpColU + //update single row/column flags (depends on CSE) + _rowLowerEqualsUpper = (input2 == input3); + _colLowerEqualsUpper = (input4 == input5); + //parse input information boolean allRows = ( input2 instanceof LiteralOp && HopRewriteUtils.getIntValueSafe((LiteralOp)input2)==1 http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/LiteralOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/LiteralOp.java b/src/main/java/org/apache/sysml/hops/LiteralOp.java index 835b7ab..e089177 100644 --- a/src/main/java/org/apache/sysml/hops/LiteralOp.java +++ b/src/main/java/org/apache/sysml/hops/LiteralOp.java @@ -183,7 +183,7 @@ public class LiteralOp extends Hop //do nothing; it is a scalar } - public long getLongValue() throws HopsException + public long getLongValue() { switch( getValueType() ) { case INT: @@ -192,8 +192,10 @@ public class LiteralOp extends Hop return UtilFunctions.toLong(value_double); case STRING: return Long.parseLong(value_string); + case BOOLEAN: + return value_boolean ? 1 : 0; default: - throw new HopsException("Can not coerce an object of type " + getValueType() + " into Long."); + return -1; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 6479917..f1dfb91 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -36,6 +36,8 @@ import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeCell; import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; @@ -490,8 +492,15 @@ public class SpoofCompiler //remove spurious lookups on main input of cell template if( tpl instanceof CNodeCell || tpl instanceof CNodeOuterProduct ) { - CNode in1 = tpl.getInput().get(0); - rFindAndRemoveLookup(tpl.getOutput(), in1.getVarname()); + CNodeData in1 = (CNodeData)tpl.getInput().get(0); + rFindAndRemoveLookup(tpl.getOutput(), in1); + } + + //remove invalid plans with column indexing on main input + if( tpl instanceof CNodeCell ) { + CNodeData in1 = (CNodeData)tpl.getInput().get(0); + if( rHasLookupRC1(tpl.getOutput(), in1) ) + cplans2.remove(e.getKey()); } //remove cplan w/ single op and w/o agg @@ -517,17 +526,32 @@ public class SpoofCompiler rCollectLeafIDs(c, leafs); } - private static void rFindAndRemoveLookup(CNode node, String nodeName) { + private static void rFindAndRemoveLookup(CNode node, CNodeData mainInput) { for( int i=0; i<node.getInput().size(); i++ ) { CNode tmp = node.getInput().get(i); if( tmp instanceof CNodeUnary && (((CNodeUnary)tmp).getType()==UnaryType.LOOKUP_R || ((CNodeUnary)tmp).getType()==UnaryType.LOOKUP_RC) - && tmp.getInput().get(0).getVarname().equals(nodeName) ) + && tmp.getInput().get(0) instanceof CNodeData + && ((CNodeData)tmp.getInput().get(0)).getHopID()==mainInput.getHopID() ) { node.getInput().set(i, tmp.getInput().get(0)); } else - rFindAndRemoveLookup(tmp, nodeName); + rFindAndRemoveLookup(tmp, mainInput); + } + } + + private static boolean rHasLookupRC1(CNode node, CNodeData mainInput) { + boolean ret = false; + for( int i=0; i<node.getInput().size() && !ret; i++ ) { + CNode tmp = node.getInput().get(i); + if( tmp instanceof CNodeTernary && ((CNodeTernary)tmp).getType()==TernaryType.LOOKUP_RC1 + && tmp.getInput().get(0) instanceof CNodeData + && ((CNodeData)tmp.getInput().get(0)).getHopID() == mainInput.getHopID()) + ret = true; + else + ret |= rHasLookupRC1(tmp, mainInput); } + return ret; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java index caf7b6a..527da28 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java @@ -173,7 +173,7 @@ public class CNodeCell extends CNodeTpl StringBuilder sb = new StringBuilder(); sb.append("SPOOF CELLWISE [type="); sb.append(_type.name()); - sb.append(", spafeSafe="+_sparseSafe); + sb.append(", sparseSafe="+_sparseSafe); sb.append(", castdtm="+_requiresCastdtm); sb.append(", mc="+_multipleConsumers); sb.append("]"); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java index 9fdae3b..c9b389d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java @@ -27,7 +27,8 @@ import org.apache.sysml.parser.Expression.DataType; public class CNodeTernary extends CNode { public enum TernaryType { - PLUS_MULT, MINUS_MULT; + PLUS_MULT, MINUS_MULT, + LOOKUP_RC1; public static boolean contains(String value) { for( TernaryType tt : values() ) @@ -44,6 +45,9 @@ public class CNodeTernary extends CNode case MINUS_MULT: return " double %TMP% = %IN1% - %IN2% * %IN3%;\n;\n" ; + case LOOKUP_RC1: + return " double %TMP% = %IN1%[rowIndex*%IN2%+%IN3%-1];\n"; + default: throw new RuntimeException("Invalid ternary type: "+this.toString()); } @@ -97,6 +101,7 @@ public class CNodeTernary extends CNode switch(_type) { case PLUS_MULT: return "t(+*)"; case MINUS_MULT: return "t(-*)"; + case LOOKUP_RC1: return "u(ixrc1)"; default: return super.toString(); } @@ -107,6 +112,7 @@ public class CNodeTernary extends CNode switch(_type) { case PLUS_MULT: case MINUS_MULT: + case LOOKUP_RC1: _rows = 0; _cols = 0; _dataType= DataType.SCALAR; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java index f5d11d1..c645eed 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java @@ -32,6 +32,8 @@ import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.IndexingOp; +import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; @@ -56,7 +58,8 @@ public class CellTpl extends BaseTpl @Override public boolean open(Hop hop) { - return isValidOperation(hop); + return isValidOperation(hop) + || (hop instanceof IndexingOp && ((IndexingOp)hop).isColLowerEqualsUpper()); } @Override @@ -197,6 +200,14 @@ public class CellTpl extends BaseTpl out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString())); } + else if( hop instanceof IndexingOp ) + { + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + out = new CNodeTernary(cdata1, + TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), + TemplateUtils.createCNodeData(hop.getInput().get(4), true), + TernaryType.LOOKUP_RC1); + } else if( HopRewriteUtils.isTransposeOperation(hop) ) { out = tmp.get(hop.getInput().get(0).getHopID()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java index 9fa5efd..1aa380b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java @@ -28,11 +28,15 @@ import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.IndexingOp; +import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeRowAgg; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; @@ -90,18 +94,16 @@ public class RowAggTpl extends BaseTpl { public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) { //recursively process required cplan output HashSet<Hop> inHops = new HashSet<Hop>(); + HashMap<String, Hop> inHops2 = new HashMap<String,Hop>(); HashMap<Long, CNode> tmp = new HashMap<Long, CNode>(); hop.resetVisitStatus(); - rConstructCplan(hop, memo, tmp, inHops, compileLiterals); + rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals); hop.resetVisitStatus(); //reorder inputs (ensure matrix is first input) LinkedList<Hop> sinHops = new LinkedList<Hop>(inHops); - for( Hop h : inHops ) - if( h.getDataType().isMatrix() && !TemplateUtils.isVector(h) ) { - sinHops.remove(h); - sinHops.addFirst(h); - } + Hop X = inHops2.get("X"); + sinHops.remove(X); sinHops.addFirst(X); //construct template node ArrayList<CNode> inputs = new ArrayList<CNode>(); @@ -114,14 +116,14 @@ public class RowAggTpl extends BaseTpl { return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); } - private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) + private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) { //recursively process required childs MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.RowAggTpl); for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); if( me.isPlanRef(i) ) - rConstructCplan(c, memo, tmp, inHops, compileLiterals); + rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals); else { CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals); tmp.put(c.getHopID(), cdata); @@ -137,8 +139,10 @@ public class RowAggTpl extends BaseTpl { if( ((AggUnaryOp)hop).getDirection() == Direction.Row && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) { if(hop.getInput().get(0).getDim2()==1) out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R); - else + else { out = new CNodeUnary(cdata1, UnaryType.ROW_SUMS); + inHops2.put("X", hop.getInput().get(0)); + } } else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) { //vector div add without temporary copy @@ -167,8 +171,10 @@ public class RowAggTpl extends BaseTpl { if(hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1) out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), (cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT); - else + else { out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT); + inHops2.put("X", hop.getInput().get(0)); + } } } else if(hop instanceof BinaryOp) @@ -194,6 +200,14 @@ public class RowAggTpl extends BaseTpl { out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } } + else if( hop instanceof IndexingOp ) + { + CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); + out = new CNodeTernary(cdata1, + TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), + TemplateUtils.createCNodeData(hop.getInput().get(4), true), + TernaryType.LOOKUP_RC1); + } if( out.getDataType().isMatrix() ) { out.setNumRows(hop.getDim1()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 5af850f..3345ee1 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -1234,8 +1234,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] if( hi instanceof IndexingOp - && ((IndexingOp)hi).getRowLowerEqualsUpper() - && ((IndexingOp)hi).getColLowerEqualsUpper() + && ((IndexingOp)hi).isRowLowerEqualsUpper() + && ((IndexingOp)hi).isColLowerEqualsUpper() && hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer && HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java index 273436e..e3e55fe 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java @@ -133,12 +133,12 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule && right.getInput().get(0) instanceof IndexingOp ) { IndexingOp ix = (IndexingOp)right.getInput().get(0); - if( ix.getRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp + if( ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar) ){ leftScalar = true; rowIx = true; } - else if( ix.getColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp + else if( ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar) ){ leftScalar = true; rowIx = false; @@ -152,12 +152,12 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule && left.getInput().get(0) instanceof IndexingOp ) { IndexingOp ix = (IndexingOp)left.getInput().get(0); - if( ix.getRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp + if( ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar) ){ rightScalar = true; rowIx = true; } - else if( ix.getColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp + else if( ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar) ){ rightScalar = true; rowIx = false; @@ -236,7 +236,7 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule IndexingOp rix1 = (IndexingOp) lixrhs.getInput().get(1); //check for rowwise - if( lix.getRowLowerEqualsUpper() && rix0.getRowLowerEqualsUpper() && rix1.getRowLowerEqualsUpper() + if( lix.getRowLowerEqualsUpper() && rix0.isRowLowerEqualsUpper() && rix1.isRowLowerEqualsUpper() && lix.getInput().get(2).getName().equals(itervar) && rix0.getInput().get(1).getName().equals(itervar) && rix1.getInput().get(1).getName().equals(itervar)) @@ -245,7 +245,7 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule rowIx = true; } //check for colwise - if( lix.getColLowerEqualsUpper() && rix0.getColLowerEqualsUpper() && rix1.getColLowerEqualsUpper() + if( lix.getColLowerEqualsUpper() && rix0.isColLowerEqualsUpper() && rix1.isColLowerEqualsUpper() && lix.getInput().get(4).getName().equals(itervar) && rix0.getInput().get(3).getName().equals(itervar) && rix1.getInput().get(3).getName().equals(itervar)) @@ -406,14 +406,14 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule boolean[] ret = new boolean[2]; //apply, rowIx //check for rowwise - if( lix.getRowLowerEqualsUpper() && rix.getRowLowerEqualsUpper() + if( lix.getRowLowerEqualsUpper() && rix.isRowLowerEqualsUpper() && lix.getInput().get(2).getName().equals(itervar) && rix.getInput().get(1).getName().equals(itervar) ) { ret[0] = true; ret[1] = true; } //check for colwise - if( lix.getColLowerEqualsUpper() && rix.getColLowerEqualsUpper() + if( lix.getColLowerEqualsUpper() && rix.isColLowerEqualsUpper() && lix.getInput().get(4).getName().equals(itervar) && rix.getInput().get(3).getName().equals(itervar) ) { ret[0] = true; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java index f5af292..4ce1d43 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java @@ -107,8 +107,8 @@ public class RewriteIndexingVectorization extends HopRewriteRule if( hop instanceof IndexingOp ) //right indexing { IndexingOp ihop0 = (IndexingOp) hop; - boolean isSingleRow = ihop0.getRowLowerEqualsUpper(); - boolean isSingleCol = ihop0.getColLowerEqualsUpper(); + boolean isSingleRow = ihop0.isRowLowerEqualsUpper(); + boolean isSingleCol = ihop0.isColLowerEqualsUpper(); boolean appliedRow = false; //search for multiple indexing in same row @@ -120,7 +120,7 @@ public class RewriteIndexingVectorization extends HopRewriteRule ihops.add(ihop0); for( Hop c : input.getParent() ){ if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input - && ((IndexingOp) c).getRowLowerEqualsUpper() + && ((IndexingOp) c).isRowLowerEqualsUpper() && c.getInput().get(1)==ihop0.getInput().get(1) ) { ihops.add( c ); @@ -159,7 +159,7 @@ public class RewriteIndexingVectorization extends HopRewriteRule ihops.add(ihop0); for( Hop c : input.getParent() ){ if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input - && ((IndexingOp) c).getColLowerEqualsUpper() + && ((IndexingOp) c).isColLowerEqualsUpper() && c.getInput().get(3)==ihop0.getInput().get(3) ) { ihops.add( c ); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index 064c987..101bad8 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -35,11 +35,13 @@ import org.apache.sysml.test.utils.TestUtils; public class RowAggTmplTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "rowAggPattern1"; - private static final String TEST_NAME2 = "rowAggPattern2"; - private static final String TEST_NAME3 = "rowAggPattern3"; - private static final String TEST_NAME4 = "rowAggPattern4"; - private static final String TEST_NAME5 = "rowAggPattern5"; + private static final String TEST_NAME = "rowAggPattern"; + private static final String TEST_NAME1 = TEST_NAME+"1"; + private static final String TEST_NAME2 = TEST_NAME+"2"; + private static final String TEST_NAME3 = TEST_NAME+"3"; + private static final String TEST_NAME4 = TEST_NAME+"4"; + private static final String TEST_NAME5 = TEST_NAME+"5"; + private static final String TEST_NAME6 = TEST_NAME+"6"; private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -51,11 +53,8 @@ public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "0" }) ); - addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "1" }) ); - addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "2" }) ); - addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "3" }) ); - addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "4" }) ); + for(int i=1; i<=6; i++) + addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @Test @@ -83,6 +82,11 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME5, true, ExecType.CP ); } + @Test + public void testCodegenRowAggRewrite6() { + testCodegenIntegration( TEST_NAME6, true, ExecType.CP ); + } + @Test public void testCodegenRowAgg1() { testCodegenIntegration( TEST_NAME1, false, ExecType.CP ); @@ -108,6 +112,11 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME5, false, ExecType.CP ); } + @Test + public void testCodegenRowAgg6() { + testCodegenIntegration( TEST_NAME6, false, ExecType.CP ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/test/scripts/functions/codegen/rowAggPattern6.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern6.R b/src/test/scripts/functions/codegen/rowAggPattern6.R new file mode 100644 index 0000000..af64a5f --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern6.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(1,15), 5, 3, byrow=TRUE); +v = seq(1,3); +P = cbind(seq(1,5),seq(2,6)); + +S = t(X) %*% ((P[,1] * (1-P[,1])) * (X %*% v)); + +writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2893e1ae/src/test/scripts/functions/codegen/rowAggPattern6.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern6.dml b/src/test/scripts/functions/codegen/rowAggPattern6.dml new file mode 100644 index 0000000..e0521c8 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern6.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,15), rows=5, cols=3); +v = seq(1,3); +P = cbind(seq(1,5),seq(2,6)); + +S = t(X) %*% ((P[,1] * (1-P[,1])) * (X %*% v)); + +write(S,$1) +
