Repository: incubator-systemml Updated Branches: refs/heads/master 0dd095ffb -> b3120ce24
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3120ce2/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
new file mode 100644
index 0000000..61c73d3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeRowAgg;
+import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+
+/**
+ * Code generation template for row aggregates. A template is opened at
+ * aggregates with a vector output (e.g., rowSums(X), X %*% v), extended by
+ * fusing supported operations, and closed at column aggregates (e.g.,
+ * colSums(X), t(X) %*% y).
+ */
+public class TemplateRowAgg extends TemplateBase {
+
+	public TemplateRowAgg() {
+		super(TemplateType.RowAggTpl);
+	}
+
+	/**
+	 * Opens a row aggregate template at a sum aggregate or a matrix
+	 * multiplication (without transposed left input) with vector output.
+	 */
+	@Override
+	public boolean open(Hop hop) {
+		//any unary or binary aggregate operation with a vector output, but exclude binary aggregate
+		//with transposed input to avoid counter-productive fusion
+		return ( ((hop instanceof AggBinaryOp && hop.getInput().get(1).getDim1()>1
+				&& !HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))
+			|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM ))
+			&& ( (hop.getDim1()==1 && hop.getDim2()!=1) || (hop.getDim1()!=1 && hop.getDim2()==1) ) );
+	}
+
+	/**
+	 * Extends an open template by the given hop. Only operations that
+	 * rConstructCplan can translate are accepted: binary operations with a
+	 * matrix first input only for division, and column aggregates only for sum.
+	 */
+	@Override
+	public boolean fuse(Hop hop, Hop input) {
+		return !isClosed() &&
+			( (hop instanceof BinaryOp && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
+					|| HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) && isSupportedBinaryOp(hop))
+			|| (hop instanceof UnaryOp && TemplateCell.isValidOperation(hop))
+			|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()==Direction.Col
+				&& ((AggUnaryOp)hop).getOp()==AggOp.SUM)
+			|| (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
+	}
+
+	/**
+	 * Merges a cell template of a vector input into this template at a
+	 * supported binary operation.
+	 */
+	@Override
+	public boolean merge(Hop hop, Hop input) {
+		//merge rowagg tpl with cell tpl if input is a vector (matrix-scalar/vector-vector ops)
+		return !isClosed() &&
+			(hop instanceof BinaryOp && input.getDim2()==1 && isSupportedBinaryOp(hop));
+	}
+
+	/**
+	 * Closes the template (as valid) at column-wise unary aggregates and at
+	 * matrix multiplications with a transposed left input.
+	 */
+	@Override
+	public CloseType close(Hop hop) {
+		//close on column aggregate (e.g., colSums, t(X)%*%y)
+		if( hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()==Direction.Col
+			|| (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) )
+			return CloseType.CLOSED_VALID;
+		else
+			return CloseType.OPEN;
+	}
+
+	/**
+	 * Constructs the cplan rooted at the given hop: builds the cnodes of all
+	 * fused operations, drops scalar literal inputs, orders the remaining
+	 * inputs with the main matrix input first, and wraps the output cnode
+	 * into a {@link CNodeRowAgg}.
+	 *
+	 * @param hop root hop of the fused operator
+	 * @param memo memo table holding the selected plans
+	 * @param compileLiterals passed through to the creation of input cnodes
+	 * @return pair of the ordered input hops and the constructed template
+	 */
+	@Override
+	public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
+		//recursively process required cplan output
+		HashSet<Hop> inHops = new HashSet<Hop>();
+		HashMap<String, Hop> inHops2 = new HashMap<String,Hop>();
+		HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();
+		hop.resetVisitStatus();
+		rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals);
+		hop.resetVisitStatus();
+
+		//reorder inputs (ensure matrix is first input, and other inputs ordered by size)
+		List<Hop> sinHops = inHops.stream()
+			.filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral()))
+			.sorted(new HopInputComparator(inHops2.get("X"))).collect(Collectors.toList());
+
+		//construct template node
+		ArrayList<CNode> inputs = new ArrayList<CNode>();
+		for( Hop in : sinHops )
+			inputs.add(tmp.get(in.getHopID()));
+		CNode output = tmp.get(hop.getHopID());
+		CNodeRowAgg tpl = new CNodeRowAgg(inputs, output);
+
+		// return cplan instance
+		return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
+	}
+
+	/**
+	 * Indicates if rConstructCplan can build a cnode for the given binary
+	 * operation: if the first input is a matrix (more than one row and
+	 * column), only division is supported.
+	 */
+	private static boolean isSupportedBinaryOp(Hop hop) {
+		Hop in1 = hop.getInput().get(0);
+		boolean matrixInput = in1.getDim1() > 1 && in1.getDim2() > 1;
+		return !matrixInput || ((BinaryOp)hop).getOp() == OpOp2.DIV;
+	}
+
+	/**
+	 * Recursively constructs the cnode of the given hop and its fused inputs
+	 * according to the best row aggregate plan in the memo table; inputs that
+	 * are not plan references become template inputs.
+	 *
+	 * @param hop current hop
+	 * @param memo memo table holding the selected plans
+	 * @param tmp constructed cnodes by hop ID
+	 * @param inHops collected template input hops
+	 * @param inHops2 named special inputs, where "X" is the main matrix input
+	 * @param compileLiterals passed through to the creation of input cnodes
+	 * @throws RuntimeException if no cnode can be constructed for the hop
+	 */
+	private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals)
+	{
+		//recursively process required childs
+		MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.RowAggTpl);
+		for( int i=0; i<hop.getInput().size(); i++ ) {
+			Hop c = hop.getInput().get(i);
+			if( me.isPlanRef(i) )
+				rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
+			else {
+				CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
+				tmp.put(c.getHopID(), cdata);
+				inHops.add(c);
+			}
+		}
+
+		//construct cnode for current hop
+		CNode out = null;
+		if(hop instanceof AggUnaryOp)
+		{
+			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
+			if( ((AggUnaryOp)hop).getDirection() == Direction.Row && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
+				if(hop.getInput().get(0).getDim2()==1)
+					out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R);
+				else {
+					out = new CNodeUnary(cdata1, UnaryType.ROW_SUMS);
+					inHops2.put("X", hop.getInput().get(0));
+				}
+			}
+			else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
+				//vector div add without temporary copy
+				if(cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType()==BinType.VECT_DIV_SCALAR)
+					out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), BinType.VECT_DIV_ADD);
+				else
+					out = cdata1;
+			}
+		}
+		else if(hop instanceof AggBinaryOp)
+		{
+			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
+			CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
+
+			if( HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)) )
+			{
+				//correct input under transpose
+				cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
+				inHops.remove(hop.getInput().get(0));
+				inHops.add(hop.getInput().get(0).getInput().get(0));
+
+				out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
+			}
+			else
+			{
+				if(hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1)
+					out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0),
+						(cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
+				else {
+					out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
+					inHops2.put("X", hop.getInput().get(0));
+				}
+			}
+		}
+		else if(hop instanceof UnaryOp)
+		{
+			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
+			if( TemplateUtils.isColVector(cdata1) )
+				cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
+			else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() )
+				cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC);
+
+			String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+			out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
+		}
+		else if(hop instanceof BinaryOp)
+		{
+			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
+			CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
+
+			// if one input is a matrix then we need to do vector by scalar operations
+			if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 )
+			{
+				if (((BinaryOp)hop).getOp()== OpOp2.DIV)
+					out = new CNodeBinary(cdata1, cdata2, BinType.VECT_DIV_SCALAR);
+			}
+			else //one input is a vector/scalar other is a scalar
+			{
+				String primitiveOpName = ((BinaryOp)hop).getOp().toString();
+				if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) ) {
+					cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
+				}
+				if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) ) {
+					cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
+				}
+				out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
+			}
+		}
+		else if( hop instanceof IndexingOp )
+		{
+			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
+			out = new CNodeTernary(cdata1,
+				TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true),
+				TemplateUtils.createCNodeData(hop.getInput().get(4), true),
+				TernaryType.LOOKUP_RC1);
+		}
+
+		//fail with a descriptive error instead of a NullPointerException below
+		//if no cnode could be constructed for this hop
+		if( out == null )
+			throw new RuntimeException("Unsupported operation in row aggregate template: " +
+				hop.getClass().getSimpleName() + " (hopID=" + hop.getHopID() + ")");
+
+		if( out.getDataType().isMatrix() ) {
+			out.setNumRows(hop.getDim1());
+			out.setNumCols(hop.getDim2());
+		}
+
+		tmp.put(hop.getHopID(), out);
+	}
+
+	/**
+	 * Comparator to order input hops of the row aggregate template. We try
+	 * to order matrices-vectors-scalars via sorting by number of cells but
+	 * we keep the given main input always at the first position.
+	 */
+	public static class HopInputComparator implements Comparator<Hop>
+	{
+		private final Hop _X;
+
+		public HopInputComparator(Hop X) {
+			_X = X;
+		}
+
+		@Override
+		public int compare(Hop h1, Hop h2) {
+			//descending order of sort keys (main input first, scalars last)
+			return Long.compare(getSortKey(h2), getSortKey(h1));
+		}
+
+		/**
+		 * Sort key of an input hop: scalars get the smallest and the main
+		 * input the largest key; matrices/vectors are keyed by their number
+		 * of cells, or just below the main input if dimensions are unknown.
+		 */
+		private long getSortKey(Hop h) {
+			return h.getDataType()==DataType.SCALAR ? Long.MIN_VALUE :
+				(h==_X) ? Long.MAX_VALUE :
+				h.dimsKnown() ? h.getDim1()*h.getDim2() : Long.MAX_VALUE-1;
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3120ce2/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 6c2fe8d..4f9a5c7 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -40,7 +40,7 @@ import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; -import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType; +import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import
org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; @@ -51,7 +51,7 @@ import org.apache.sysml.runtime.util.UtilFunctions; public class TemplateUtils { - public static final BaseTpl[] TEMPLATES = new BaseTpl[]{new RowAggTpl(), new CellTpl(), new OuterProductTpl()}; + public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new TemplateRowAgg(), new TemplateCell(), new TemplateOuterProduct()}; public static boolean isVector(Hop hop) { return (hop.getDataType() == DataType.MATRIX @@ -174,16 +174,16 @@ public class TemplateUtils return ret; } - public static BaseTpl createTemplate(TemplateType type) { + public static TemplateBase createTemplate(TemplateType type) { return createTemplate(type, false); } - public static BaseTpl createTemplate(TemplateType type, boolean closed) { - BaseTpl tpl = null; + public static TemplateBase createTemplate(TemplateType type, boolean closed) { + TemplateBase tpl = null; switch( type ) { - case CellTpl: tpl = new CellTpl(); break; - case RowAggTpl: tpl = new RowAggTpl(); break; - case OuterProdTpl: tpl = new OuterProductTpl(); break; + case CellTpl: tpl = new TemplateCell(); break; + case RowAggTpl: tpl = new TemplateRowAgg(); break; + case OuterProdTpl: tpl = new TemplateOuterProduct(); break; } tpl._closed = closed; return tpl; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3120ce2/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index 101bad8..0d9bc1d 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -42,6 +42,7 @@ 
public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME4 = TEST_NAME+"4"; private static final String TEST_NAME5 = TEST_NAME+"5"; private static final String TEST_NAME6 = TEST_NAME+"6"; + private static final String TEST_NAME7 = TEST_NAME+"7"; private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -53,7 +54,7 @@ public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=6; i++) + for(int i=1; i<=7; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @@ -87,6 +88,11 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME6, true, ExecType.CP ); } + @Test + public void testCodegenRowAggRewrite7() { + testCodegenIntegration( TEST_NAME7, true, ExecType.CP ); + } + @Test public void testCodegenRowAgg1() { testCodegenIntegration( TEST_NAME1, false, ExecType.CP ); @@ -117,6 +123,11 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME6, false, ExecType.CP ); } + @Test + public void testCodegenRowAgg7() { + testCodegenIntegration( TEST_NAME7, false, ExecType.CP ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3120ce2/src/test/scripts/functions/codegen/rowAggPattern7.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern7.R b/src/test/scripts/functions/codegen/rowAggPattern7.R new file mode 100644 index 0000000..a343a8b --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern7.R @@ -0,0 +1,34 @@ 
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# R reference for rowAggPattern7.dml: writes S = t(X) %*% (X %*% v - y) to <args[2]>S
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(1,15), 5, 3, byrow=TRUE);
+v = seq(1,3);
+y = abs(seq(1,5));
+
+r = X %*% v - y;
+S = t(X) %*% r;
+print(sum(r*r));
+
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep=""));
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3120ce2/src/test/scripts/functions/codegen/rowAggPattern7.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern7.dml b/src/test/scripts/functions/codegen/rowAggPattern7.dml
new file mode 100644
index 0000000..2ad9fa6
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern7.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Row-aggregate codegen test pattern 7: writes S = t(X) %*% (X %*% v - y) to $1
+X = matrix(seq(1,15), rows=5, cols=3);
+v = seq(1,3);
+y = abs(seq(1,5));
+
+r = X %*% v - y;
+S = t(X) %*% r;
+print(sum(r*r));
+
+write(S,$1)
+
