[SYSTEMML-2081] Support for nary cbind in codegen row operations

This patch adds support for nary cbind operations such as cbind(X,Y,Z) in codegen row operations. In detail, this includes new tests, modified OFMC conditions of the row template, a new nary CPlan node, and additional runtime vector operations.
The support for nary cbind is an important precondition for effective operator fusion in many deep learning architectures, where parallel lanes are often consolidated via cbind, which allows compiling a fused operator with a single output for multiple lanes. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b02599c6 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b02599c6 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b02599c6 Branch: refs/heads/master Commit: b02599c6df454bdab8b8b9ebcd95eb4a4ee05325 Parents: 8cb5532 Author: Matthias Boehm <[email protected]> Authored: Thu Jan 25 17:39:59 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Jan 25 17:39:59 2018 -0800 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeNary.java | 145 +++++++++++++++++++ .../hops/codegen/template/TemplateRow.java | 33 ++++- .../runtime/codegen/LibSpoofPrimitives.java | 13 +- .../functions/codegen/RowAggTmplTest.java | 92 +++++++----- .../scripts/functions/codegen/rowAggPattern35.R | 33 +++++ .../functions/codegen/rowAggPattern35.dml | 30 ++++ 6 files changed, 301 insertions(+), 45 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java new file mode 100644 index 0000000..7f19194 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.ArrayList; + +import org.apache.sysml.hops.codegen.template.TemplateUtils; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.util.UtilFunctions; + +public class CNodeNary extends CNode +{ + public enum NaryType { + VECT_CBIND; + public static boolean contains(String value) { + for( NaryType bt : values() ) + if( bt.name().equals(value) ) + return true; + return false; + } + public String getTemplate(boolean sparseGen, long len, ArrayList<CNode> inputs) { + switch (this) { + case VECT_CBIND: + StringBuilder sb = new StringBuilder(); + sb.append(" double[] %TMP% = LibSpoofPrimitives.allocVector("+len+", true); //nary cbind\n"); + for( int i=0, off=0; i<inputs.size(); i++ ) { + CNode input = inputs.get(i); + boolean sparseInput = sparseGen && input instanceof CNodeData + && input.getVarname().startsWith("a"); + String varj = input.getVarname(); + String pos = (input instanceof CNodeData && input.getDataType().isMatrix()) ? + (!varj.startsWith("b")) ? varj+"i" : TemplateUtils.isMatrix(input) ? + varj + ".pos(rix)" : "0" : "0"; + sb.append( sparseInput ?
+ " LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, " + +varj+"ix, "+pos+", "+off+", alen);\n" /* sparse row: scatter only its alen nonzeros; passing _cols would read into the next rows */ : + " LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj) + +", %TMP%, "+pos+", "+off+", "+input._cols+");\n"); + off += input._cols; + } + return sb.toString(); + default: + throw new RuntimeException("Invalid nary type: "+this.toString()); + } + } + public boolean isVectorPrimitive() { + return this == VECT_CBIND; + } + } + + private final NaryType _type; + + public CNodeNary( CNode[] inputs, NaryType type ) { + for( CNode in : inputs ) + _inputs.add(in); + _type = type; + setOutputDims(); + } + + public NaryType getType() { + return _type; + } + + @Override + public String codegen(boolean sparse) { + if( isGenerated() ) + return ""; + + StringBuilder sb = new StringBuilder(); + + //generate children + for(CNode in : _inputs) + sb.append(in.codegen(sparse)); + + //generate nary operation (use sparse template, if data input) + String var = createVarname(); + String tmp = _type.getTemplate(sparse, _cols, _inputs); + tmp = tmp.replace("%TMP%", var); + + sb.append(tmp); + + //mark as generated + _generated = true; + + return sb.toString(); + } + + @Override + public String toString() { + switch(_type) { + case VECT_CBIND: return "n(cbind)"; + default: + return "m("+_type.name().toLowerCase()+")"; + } + } + + @Override + public void setOutputDims() { + switch(_type) { + case VECT_CBIND: + _rows = _inputs.get(0)._rows; + _cols = 0; + for(CNode in : _inputs) + _cols += in._cols; + _dataType = DataType.MATRIX; + break; + } + } + + @Override + public int hashCode() { + if( _hash == 0 ) { + _hash = UtilFunctions.intHashCode( + super.hashCode(), _type.hashCode()); + } + return _hash; + } + + @Override + public boolean equals(Object o) { + if( !(o instanceof CNodeNary) ) + return false; + + CNodeNary that = (CNodeNary) o; + return super.equals(that) + && _type == that._type; + } +}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 8620971..5bc8f4f 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -39,6 +39,8 @@ import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.hops.codegen.cplan.CNodeData; +import org.apache.sysml.hops.codegen.cplan.CNodeNary; +import org.apache.sysml.hops.codegen.cplan.CNodeNary.NaryType; import org.apache.sysml.hops.codegen.cplan.CNodeRow; import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; @@ -50,6 +52,7 @@ import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.matrix.data.LibMatrixMult; import org.apache.sysml.runtime.matrix.data.Pair; @@ -77,8 +80,8 @@ public class TemplateRow extends TemplateBase public boolean open(Hop hop) { return (hop instanceof BinaryOp && hop.dimsKnown() && isValidBinaryOperation(hop) && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) - || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() - && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1))) + || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) + || 
(HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (hop instanceof AggBinaryOp && hop.dimsKnown() && LibMatrixMult.isSkinnyRightHandSide( @@ -101,9 +104,9 @@ public class TemplateRow extends TemplateBase @Override public boolean fuse(Hop hop, Hop input) { return !isClosed() && - ( (hop instanceof BinaryOp && isValidBinaryOperation(hop) ) - || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().indexOf(input)==0 - && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1))) + ( (hop instanceof BinaryOp && isValidBinaryOperation(hop)) + || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) + || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) && TemplateCell.isValidOperation(hop)) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol @@ -125,9 +128,9 @@ public class TemplateRow extends TemplateBase return !isClosed() && ((hop instanceof BinaryOp && isValidBinaryOperation(hop) && hop.getDim1() > 1 && input.getDim1()>1) - || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() - && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1))) - ||(hop instanceof AggBinaryOp + || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) + || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) + || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)) && (input.getDim2()==1 || (input==hop.getInput().get(1) && HopRewriteUtils.containsInput(input, hop.getInput().get(0).getInput().get(0)))))); @@ -432,6 
+435,20 @@ public class TemplateRow extends TemplateBase out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString())); } + else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) + { + CNode[] inputs = new CNode[hop.getInput().size()]; + for( int i=0; i<hop.getInput().size(); i++ ) { + Hop c = hop.getInput().get(i); + CNode cdata = tmp.get(c.getHopID()); + if( TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata) ) + cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c); + inputs[i] = cdata; + if( i==0 && cdata instanceof CNodeData && !inHops2.containsKey("X") ) + inHops2.put("X", c); + } + out = new CNodeNary(inputs, NaryType.VECT_CBIND); + } else if( hop instanceof ParameterizedBuiltinOp ) { CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID()); http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java index 5bef83b..8edb25e 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java @@ -206,12 +206,23 @@ public class LibSpoofPrimitives System.arraycopy(a, 0, c, ci, len); } + public static void vectWrite(double[] a, double[] c, int ai, int ci, int len) { + if( a == null ) return; + System.arraycopy(a, ai, c, ci, len); + } + public static void vectWrite(boolean[] a, boolean[] c, int[] aix) { if( a == null ) return; for( int i=0; i<aix.length; i++ ) c[aix[i]] = a[i]; } + public static void vectWrite(boolean[] a, boolean[] c, int[] aix, int ai, int ci, int alen) { + if( a == null ) return; + for( int i=ai; i<ai+alen; i++ ) + c[ci+aix[i]] = a[i]; + } + // cbind handling public static double[] 
vectCbindAdd(double[] a, double b, double[] c, int ai, int ci, int len) { @@ -1846,7 +1857,7 @@ public class LibSpoofPrimitives memPool.remove(); } - protected static double[] allocVector(int len, boolean reset) { + public static double[] allocVector(int len, boolean reset) { return allocVector(len, reset, 0); } http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index b5426ae..6124c77 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -34,7 +34,7 @@ import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; -public class RowAggTmplTest extends AutomatedTestBase +public class RowAggTmplTest extends AutomatedTestBase { private static final String TEST_NAME = "rowAggPattern"; private static final String TEST_NAME1 = TEST_NAME+"1"; //t(X)%*%(X%*%(lamda*v)) @@ -71,6 +71,7 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME32 = TEST_NAME+"32"; //X[, 1] - rowSums(X) private static final String TEST_NAME33 = TEST_NAME+"33"; //Kmeans, inner loop private static final String TEST_NAME34 = TEST_NAME+"34"; //X / rowSums(X!=0) + private static final String TEST_NAME35 = TEST_NAME+"35"; //cbind(X/rowSums(X), Y, Z) private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -82,11 +83,11 @@ public class RowAggTmplTest extends AutomatedTestBase 
@Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=34; i++) + for(int i=1; i<=35; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } - @Test + @Test public void testCodegenRowAggRewrite1CP() { testCodegenIntegration( TEST_NAME1, true, ExecType.CP ); } @@ -101,7 +102,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME1, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite2CP() { testCodegenIntegration( TEST_NAME2, true, ExecType.CP ); } @@ -116,7 +117,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME2, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite3CP() { testCodegenIntegration( TEST_NAME3, true, ExecType.CP ); } @@ -131,7 +132,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME3, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite4CP() { testCodegenIntegration( TEST_NAME4, true, ExecType.CP ); } @@ -146,7 +147,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME4, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite5CP() { testCodegenIntegration( TEST_NAME5, true, ExecType.CP ); } @@ -161,7 +162,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME5, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite6CP() { testCodegenIntegration( TEST_NAME6, true, ExecType.CP ); } @@ -176,7 +177,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME6, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite7CP() { testCodegenIntegration( TEST_NAME7, true, ExecType.CP ); } @@ -191,7 +192,7 @@ public class RowAggTmplTest extends AutomatedTestBase 
testCodegenIntegration( TEST_NAME7, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite8CP() { testCodegenIntegration( TEST_NAME8, true, ExecType.CP ); } @@ -206,7 +207,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME8, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite9CP() { testCodegenIntegration( TEST_NAME9, true, ExecType.CP ); } @@ -221,7 +222,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME9, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite10CP() { testCodegenIntegration( TEST_NAME10, true, ExecType.CP ); } @@ -236,7 +237,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME10, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite11CP() { testCodegenIntegration( TEST_NAME11, true, ExecType.CP ); } @@ -251,7 +252,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME11, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite12CP() { testCodegenIntegration( TEST_NAME12, true, ExecType.CP ); } @@ -266,7 +267,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME12, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite13CP() { testCodegenIntegration( TEST_NAME13, true, ExecType.CP ); } @@ -281,7 +282,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME13, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite14CP() { testCodegenIntegration( TEST_NAME14, true, ExecType.CP ); } @@ -296,7 +297,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME14, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite15CP() { testCodegenIntegration( TEST_NAME15, true, ExecType.CP ); } @@ 
-311,7 +312,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME15, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite16CP() { testCodegenIntegration( TEST_NAME16, true, ExecType.CP ); } @@ -326,7 +327,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME16, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite17CP() { testCodegenIntegration( TEST_NAME17, true, ExecType.CP ); } @@ -341,7 +342,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME17, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite18CP() { testCodegenIntegration( TEST_NAME18, true, ExecType.CP ); } @@ -356,7 +357,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME18, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite19CP() { testCodegenIntegration( TEST_NAME19, true, ExecType.CP ); } @@ -371,7 +372,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME19, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite20CP() { testCodegenIntegration( TEST_NAME20, true, ExecType.CP ); } @@ -386,7 +387,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME20, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite21CP() { testCodegenIntegration( TEST_NAME21, true, ExecType.CP ); } @@ -401,7 +402,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME21, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite22CP() { testCodegenIntegration( TEST_NAME22, true, ExecType.CP ); } @@ -416,7 +417,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME22, false, ExecType.SPARK ); } - @Test + @Test public void 
testCodegenRowAggRewrite23CP() { testCodegenIntegration( TEST_NAME23, true, ExecType.CP ); } @@ -431,7 +432,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME23, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite24CP() { testCodegenIntegration( TEST_NAME24, true, ExecType.CP ); } @@ -446,7 +447,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME24, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite25CP() { testCodegenIntegration( TEST_NAME25, true, ExecType.CP ); } @@ -461,7 +462,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME25, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite26CP() { testCodegenIntegration( TEST_NAME26, true, ExecType.CP ); } @@ -476,7 +477,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME26, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite27CP() { testCodegenIntegration( TEST_NAME27, true, ExecType.CP ); } @@ -491,7 +492,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME27, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite28CP() { testCodegenIntegration( TEST_NAME28, true, ExecType.CP ); } @@ -506,7 +507,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME28, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite29CP() { testCodegenIntegration( TEST_NAME29, true, ExecType.CP ); } @@ -521,7 +522,7 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME29, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenRowAggRewrite30CP() { testCodegenIntegration( TEST_NAME30, true, ExecType.CP ); } @@ -596,8 +597,24 @@ public class RowAggTmplTest extends AutomatedTestBase 
testCodegenIntegration( TEST_NAME34, false, ExecType.SPARK ); } + @Test + public void testCodegenRowAggRewrite35CP() { + testCodegenIntegration( TEST_NAME35, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg35CP() { + testCodegenIntegration( TEST_NAME35, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg35SP() { + testCodegenIntegration( TEST_NAME35, false, ExecType.SPARK ); + } + + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) - { + { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; RUNTIME_PLATFORM platformOld = rtplatform; switch( instType ) { @@ -620,7 +637,7 @@ public class RowAggTmplTest extends AutomatedTestBase programArgs = new String[]{"-explain", "recompile_runtime", "-stats", "-args", output("S") }; fullRScriptName = HOME + testname + ".R"; - rCmd = getRCmd(inputDir(), expectedDir()); + rCmd = getRCmd(inputDir(), expectedDir()); OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; @@ -632,7 +649,7 @@ public class RowAggTmplTest extends AutomatedTestBase HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); Assert.assertTrue(heavyHittersContainsSubString("spoofRA") - || heavyHittersContainsSubString("sp_spoofRA")); + || heavyHittersContainsSubString("sp_spoofRA")); //ensure full aggregates for certain patterns if( testname.equals(TEST_NAME15) ) @@ -647,6 +664,9 @@ public class RowAggTmplTest extends AutomatedTestBase && !heavyHittersContainsSubString(RightIndex.OPCODE)); if( testname.equals(TEST_NAME31) ) Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2)); + if( testname.equals(TEST_NAME35) ) + Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2) + && !heavyHittersContainsSubString("cbind")); } finally { rtplatform = platformOld; @@ -655,7 +675,7 @@ public class RowAggTmplTest extends AutomatedTestBase OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true; 
OptimizerUtils.ALLOW_OPERATOR_FUSION = true; } - } + } /** * Override default configuration with custom test configuration to ensure http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/test/scripts/functions/codegen/rowAggPattern35.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern35.R b/src/test/scripts/functions/codegen/rowAggPattern35.R new file mode 100644 index 0000000..d9bf8e9 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern35.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = matrix(seq(1,6000)/6000, 300, 20, byrow=TRUE); +Y = matrix(2, 300, 2); +Z = matrix(3, 300, 3); + +R = cbind(cbind(X/(rowSums(X)%*%matrix(1,1,ncol(X))), Y), Z); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/test/scripts/functions/codegen/rowAggPattern35.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern35.dml b/src/test/scripts/functions/codegen/rowAggPattern35.dml new file mode 100644 index 0000000..ee6d459 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern35.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +X = matrix(seq(1,6000)/6000, 300, 20); +Y = matrix(2, 300, 2); +Z = matrix(3, 300, 3); +while(FALSE){} + +R = cbind(X/rowSums(X), Y, Z); + +write(R, $1)
