[SYSTEMML-2081] Support for nary cbind in codegen row operations

This patch adds support for nary cbind operations such as cbind(X,Y,Z)
in codegen row templates. In detail, this includes new tests, modified
open-fuse-merge-close (OFMC) conditions of the row template, a new nary
CPlan node, and additional runtime vector primitives.

Support for nary cbind is an important precondition for effective
operator fusion in many deep learning architectures, where parallel
lanes are often consolidated via cbind. It allows compiling a single
fused operator with one output for multiple lanes.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b02599c6
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b02599c6
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b02599c6

Branch: refs/heads/master
Commit: b02599c6df454bdab8b8b9ebcd95eb4a4ee05325
Parents: 8cb5532
Author: Matthias Boehm <[email protected]>
Authored: Thu Jan 25 17:39:59 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Thu Jan 25 17:39:59 2018 -0800

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeNary.java     | 145 +++++++++++++++++++
 .../hops/codegen/template/TemplateRow.java      |  33 ++++-
 .../runtime/codegen/LibSpoofPrimitives.java     |  13 +-
 .../functions/codegen/RowAggTmplTest.java       |  92 +++++++-----
 .../scripts/functions/codegen/rowAggPattern35.R |  33 +++++
 .../functions/codegen/rowAggPattern35.dml       |  30 ++++
 6 files changed, 301 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
new file mode 100644
index 0000000..7f19194
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.util.UtilFunctions;
+
+/**
+ * CPlan node for nary operations over an arbitrary number of inputs.
+ * Currently, the only supported type is VECT_CBIND, which column-binds
+ * the (row) vectors of all inputs into a single output vector.
+ */
+public class CNodeNary extends CNode
+{
+       public enum NaryType {
+               VECT_CBIND;
+               
+               /**
+                * Indicates if the given name refers to a valid nary type.
+                * 
+                * @param value type name
+                * @return true if value equals the name of a NaryType constant
+                */
+               public static boolean contains(String value) {
+                       for( NaryType nt : values() )
+                               if( nt.name().equals(value) )
+                                       return true;
+                       return false;
+               }
+               
+               /**
+                * Creates the code template of this nary operation, where the
+                * placeholder %TMP% is subsequently replaced by the output variable.
+                * 
+                * @param sparseGen true if code is generated for the sparse code path
+                * @param len length of the output vector
+                * @param inputs input nodes in cbind order
+                * @return code template containing %TMP% placeholders
+                */
+               public String getTemplate(boolean sparseGen, long len, ArrayList<CNode> inputs) {
+                       switch (this) {
+                               case VECT_CBIND:
+                                       StringBuilder sb = new StringBuilder();
+                                       sb.append("    double[] %TMP% = LibSpoofPrimitives.allocVector("+len+", true); //nary cbind\n");
+                                       for( int i=0, off=0; i<inputs.size(); i++ ) {
+                                               CNode input = inputs.get(i);
+                                               String varj = input.getVarname();
+                                               //scalar inputs (e.g., lookups of column vectors) occupy
+                                               //exactly one output cell and are assigned directly
+                                               if( input.getDataType() == DataType.SCALAR ) {
+                                                       sb.append("    %TMP%["+off+"] = "+varj+";\n");
+                                                       off += 1;
+                                                       continue;
+                                               }
+                                               boolean sparseInput = sparseGen && input instanceof CNodeData
+                                                       && varj.startsWith("a");
+                                               //start position: main input offset (e.g., ai), side input
+                                               //row offset (b.pos(rix)), or 0 for vectors/intermediates
+                                               String pos = (input instanceof CNodeData && input.getDataType().isMatrix()) ?
+                                                       (!varj.startsWith("b")) ? varj+"i" : TemplateUtils.isMatrix(input) ?
+                                                       varj + ".pos(rix)" : "0" : "0";
+                                               //sparse main input: scatter only the alen non-zeros of the
+                                               //current row (not _cols entries) via their column indexes
+                                               sb.append( sparseInput ?
+                                                       "    LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
+                                                               +varj+"ix, "+pos+", "+off+", alen);\n" :
+                                                       "    LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
+                                                               +", %TMP%, "+pos+", "+off+", "+input._cols+");\n");
+                                               off += input._cols;
+                                       }
+                                       return sb.toString();
+                               default:
+                                       throw new RuntimeException("Invalid nary type: "+this.toString());
+                       }
+               }
+               
+               /**
+                * @return true if this type produces a vector output (currently VECT_CBIND)
+                */
+               public boolean isVectorPrimitive() {
+                       return this == VECT_CBIND;
+               }
+       }
+       
+       private final NaryType _type;
+       
+       /**
+        * Creates a nary node over the given inputs and derives its output dims.
+        * 
+        * @param inputs input nodes in operation order
+        * @param type nary operation type
+        */
+       public CNodeNary( CNode[] inputs, NaryType type ) {
+               for( CNode in : inputs )
+                       _inputs.add(in);
+               _type = type;
+               setOutputDims();
+       }
+
+       public NaryType getType() {
+               return _type;
+       }
+       
+       /**
+        * Generates the code of all inputs followed by this nary operation;
+        * returns an empty string if this node was already generated.
+        */
+       @Override
+       public String codegen(boolean sparse) {
+               if( isGenerated() )
+                       return "";
+               
+               StringBuilder sb = new StringBuilder();
+               
+               //generate children
+               for(CNode in : _inputs)
+                       sb.append(in.codegen(sparse));
+               
+               //generate nary operation (use sparse template, if data input)
+               String var = createVarname();
+               String tmp = _type.getTemplate(sparse, _cols, _inputs);
+               tmp = tmp.replace("%TMP%", var);
+               
+               sb.append(tmp);
+               
+               //mark as generated
+               _generated = true;
+               
+               return sb.toString();
+       }
+       
+       @Override
+       public String toString() {
+               switch(_type) {
+                       case VECT_CBIND: return "n(cbind)";
+                       default:
+                               return "n("+_type.name().toLowerCase()+")";
+               }
+       }
+       
+       @Override
+       public void setOutputDims() {
+               switch(_type) {
+                       case VECT_CBIND:
+                               //output width is the sum of all input widths, where
+                               //scalar inputs contribute exactly one column
+                               _rows = _inputs.get(0)._rows;
+                               _cols = 0;
+                               for(CNode in : _inputs)
+                                       _cols += (in.getDataType() == DataType.SCALAR) ? 1 : in._cols;
+                               _dataType = DataType.MATRIX;
+                               break;
+               }
+       }
+       
+       @Override
+       public int hashCode() {
+               if( _hash == 0 ) {
+                       _hash = UtilFunctions.intHashCode(
+                               super.hashCode(), _type.hashCode());
+               }
+               return _hash;
+       }
+       
+       @Override 
+       public boolean equals(Object o) {
+               if( !(o instanceof CNodeNary) )
+                       return false;
+               
+               CNodeNary that = (CNodeNary) o;
+               return super.equals(that)
+                       && _type == that._type;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 8620971..5bc8f4f 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -39,6 +39,8 @@ import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
 import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
 import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeNary;
+import org.apache.sysml.hops.codegen.cplan.CNodeNary.NaryType;
 import org.apache.sysml.hops.codegen.cplan.CNodeRow;
 import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
@@ -50,6 +52,7 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOpN;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
 import org.apache.sysml.runtime.matrix.data.Pair;
@@ -77,8 +80,8 @@ public class TemplateRow extends TemplateBase
        public boolean open(Hop hop) {
                return (hop instanceof BinaryOp && hop.dimsKnown() && 
isValidBinaryOperation(hop)
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
-                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix()
-                               && hop.dimsKnown() && 
TemplateUtils.isColVector(hop.getInput().get(1)))
+                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
hop.getDim2()==1 //MV
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
LibMatrixMult.isSkinnyRightHandSide(
@@ -101,9 +104,9 @@ public class TemplateRow extends TemplateBase
        @Override
        public boolean fuse(Hop hop, Hop input) {
                return !isClosed() && 
-                       (  (hop instanceof BinaryOp && 
isValidBinaryOperation(hop) ) 
-                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().indexOf(input)==0
-                               && hop.dimsKnown() && 
TemplateUtils.isColVector(hop.getInput().get(1)))
+                       (  (hop instanceof BinaryOp && 
isValidBinaryOperation(hop)) 
+                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
                                && TemplateCell.isValidOperation(hop))
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
@@ -125,9 +128,9 @@ public class TemplateRow extends TemplateBase
                return !isClosed() &&
                        ((hop instanceof BinaryOp && isValidBinaryOperation(hop)
                                && hop.getDim1() > 1 && input.getDim1()>1)
-                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix()
-                               && hop.dimsKnown() && 
TemplateUtils.isColVector(hop.getInput().get(1)))
-                        ||(hop instanceof AggBinaryOp
+                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || (hop instanceof AggBinaryOp
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))
                                && (input.getDim2()==1 || 
(input==hop.getInput().get(1) 
                                && HopRewriteUtils.containsInput(input, 
hop.getInput().get(0).getInput().get(0))))));
@@ -432,6 +435,20 @@ public class TemplateRow extends TemplateBase
                        out = new CNodeTernary(cdata1, cdata2, cdata3, 
                                TernaryType.valueOf(top.getOp().toString()));
                }
+               else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) 
+               {
+                       CNode[] inputs = new CNode[hop.getInput().size()];
+                       for( int i=0; i<hop.getInput().size(); i++ ) {
+                               Hop c = hop.getInput().get(i);
+                               CNode cdata = tmp.get(c.getHopID());
+                               if( TemplateUtils.isColVector(cdata) || 
TemplateUtils.isRowVector(cdata) )
+                                       cdata = 
TemplateUtils.wrapLookupIfNecessary(cdata, c);
+                               inputs[i] = cdata;
+                               if( i==0 && cdata instanceof CNodeData && 
!inHops2.containsKey("X") )
+                                       inHops2.put("X", c);
+                       }
+                       out = new CNodeNary(inputs, NaryType.VECT_CBIND);
+               }
                else if( hop instanceof ParameterizedBuiltinOp )
                {
                        CNode cdata1 = 
tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());

http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java 
b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index 5bef83b..8edb25e 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -206,12 +206,23 @@ public class LibSpoofPrimitives
                System.arraycopy(a, 0, c, ci, len);
        }
        
+       /**
+        * Copies len values of a, starting at position ai, into c starting
+        * at position ci; does nothing if a is null.
+        */
+       public static void vectWrite(double[] a, double[] c, int ai, int ci, int len) {
+               if( a != null )
+                       System.arraycopy(a, ai, c, ci, len);
+       }
+       
+       //scatter a[i] into c at column index aix[i] for all i < aix.length (no-op if a is null)
        public static void vectWrite(boolean[] a, boolean[] c, int[] aix) {
                if( a == null ) return;
                for( int i=0; i<aix.length; i++ )
                        c[aix[i]] = a[i];
        }
        
+       /**
+        * Scatters the sparse entries a[ai:ai+alen) with column indexes
+        * aix[ai:ai+alen) into c, shifted by the column offset ci.
+        * Does nothing if a is null.
+        */
+       public static void vectWrite(boolean[] a, boolean[] c, int[] aix, int ai, int ci, int alen) {
+               if( a == null ) return;
+               for( int i=ai; i<ai+alen; i++ )
+                       c[ci+aix[i]] = a[i];
+       }
+       
+       /**
+        * Double counterpart of the sparse scatter above: writes the entries
+        * a[ai:ai+alen) to c[ci+aix[i]]. The sparse nary cbind template emits
+        * vectWrite(<var>vals, <out>, <var>ix, pos, off, n) with double[] values
+        * and a double[] output, which requires this overload.
+        * Does nothing if a is null.
+        */
+       public static void vectWrite(double[] a, double[] c, int[] aix, int ai, int ci, int alen) {
+               if( a == null ) return;
+               for( int i=ai; i<ai+alen; i++ )
+                       c[ci+aix[i]] = a[i];
+       }
+       
        // cbind handling
        
        public static double[] vectCbindAdd(double[] a, double b, double[] c, 
int ai, int ci, int len) {
@@ -1846,7 +1857,7 @@ public class LibSpoofPrimitives
                memPool.remove();
        }
        
-       protected static double[] allocVector(int len, boolean reset) {
+       //public (previously protected) so that generated code, e.g., the nary
+       //cbind template in CNodeNary, can call LibSpoofPrimitives.allocVector
+       public static double[] allocVector(int len, boolean reset) {
                return allocVector(len, reset, 0);
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index b5426ae..6124c77 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -34,7 +34,7 @@ import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
 
-public class RowAggTmplTest extends AutomatedTestBase 
+public class RowAggTmplTest extends AutomatedTestBase
 {
        private static final String TEST_NAME = "rowAggPattern";
        private static final String TEST_NAME1 = TEST_NAME+"1"; 
//t(X)%*%(X%*%(lamda*v))
@@ -71,6 +71,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME32 = TEST_NAME+"32"; //X[, 1] - 
rowSums(X)
        private static final String TEST_NAME33 = TEST_NAME+"33"; //Kmeans, 
inner loop
        private static final String TEST_NAME34 = TEST_NAME+"34"; //X / 
rowSums(X!=0)
+       private static final String TEST_NAME35 = TEST_NAME+"35"; 
//cbind(X/rowSums(X), Y, Z)
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
@@ -82,11 +83,11 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=34; i++)
+               for(int i=1; i<=35; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite1CP() {
                testCodegenIntegration( TEST_NAME1, true, ExecType.CP );
        }
@@ -101,7 +102,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME1, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite2CP() {
                testCodegenIntegration( TEST_NAME2, true, ExecType.CP );
        }
@@ -116,7 +117,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME2, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite3CP() {
                testCodegenIntegration( TEST_NAME3, true, ExecType.CP );
        }
@@ -131,7 +132,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME3, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite4CP() {
                testCodegenIntegration( TEST_NAME4, true, ExecType.CP );
        }
@@ -146,7 +147,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME4, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite5CP() {
                testCodegenIntegration( TEST_NAME5, true, ExecType.CP );
        }
@@ -161,7 +162,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME5, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite6CP() {
                testCodegenIntegration( TEST_NAME6, true, ExecType.CP );
        }
@@ -176,7 +177,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME6, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite7CP() {
                testCodegenIntegration( TEST_NAME7, true, ExecType.CP );
        }
@@ -191,7 +192,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME7, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite8CP() {
                testCodegenIntegration( TEST_NAME8, true, ExecType.CP );
        }
@@ -206,7 +207,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME8, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite9CP() {
                testCodegenIntegration( TEST_NAME9, true, ExecType.CP );
        }
@@ -221,7 +222,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME9, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite10CP() {
                testCodegenIntegration( TEST_NAME10, true, ExecType.CP );
        }
@@ -236,7 +237,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME10, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite11CP() {
                testCodegenIntegration( TEST_NAME11, true, ExecType.CP );
        }
@@ -251,7 +252,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME11, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite12CP() {
                testCodegenIntegration( TEST_NAME12, true, ExecType.CP );
        }
@@ -266,7 +267,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME12, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite13CP() {
                testCodegenIntegration( TEST_NAME13, true, ExecType.CP );
        }
@@ -281,7 +282,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME13, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite14CP() {
                testCodegenIntegration( TEST_NAME14, true, ExecType.CP );
        }
@@ -296,7 +297,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME14, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite15CP() {
                testCodegenIntegration( TEST_NAME15, true, ExecType.CP );
        }
@@ -311,7 +312,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME15, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite16CP() {
                testCodegenIntegration( TEST_NAME16, true, ExecType.CP );
        }
@@ -326,7 +327,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME16, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite17CP() {
                testCodegenIntegration( TEST_NAME17, true, ExecType.CP );
        }
@@ -341,7 +342,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME17, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite18CP() {
                testCodegenIntegration( TEST_NAME18, true, ExecType.CP );
        }
@@ -356,7 +357,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME18, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite19CP() {
                testCodegenIntegration( TEST_NAME19, true, ExecType.CP );
        }
@@ -371,7 +372,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME19, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite20CP() {
                testCodegenIntegration( TEST_NAME20, true, ExecType.CP );
        }
@@ -386,7 +387,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME20, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite21CP() {
                testCodegenIntegration( TEST_NAME21, true, ExecType.CP );
        }
@@ -401,7 +402,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME21, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite22CP() {
                testCodegenIntegration( TEST_NAME22, true, ExecType.CP );
        }
@@ -416,7 +417,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME22, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite23CP() {
                testCodegenIntegration( TEST_NAME23, true, ExecType.CP );
        }
@@ -431,7 +432,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME23, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite24CP() {
                testCodegenIntegration( TEST_NAME24, true, ExecType.CP );
        }
@@ -446,7 +447,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME24, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite25CP() {
                testCodegenIntegration( TEST_NAME25, true, ExecType.CP );
        }
@@ -461,7 +462,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME25, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite26CP() {
                testCodegenIntegration( TEST_NAME26, true, ExecType.CP );
        }
@@ -476,7 +477,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME26, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite27CP() {
                testCodegenIntegration( TEST_NAME27, true, ExecType.CP );
        }
@@ -491,7 +492,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME27, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite28CP() {
                testCodegenIntegration( TEST_NAME28, true, ExecType.CP );
        }
@@ -506,7 +507,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME28, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite29CP() {
                testCodegenIntegration( TEST_NAME29, true, ExecType.CP );
        }
@@ -521,7 +522,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME29, false, ExecType.SPARK );
        }
        
-       @Test   
+       @Test
        public void testCodegenRowAggRewrite30CP() {
                testCodegenIntegration( TEST_NAME30, true, ExecType.CP );
        }
@@ -596,8 +597,24 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME34, false, ExecType.SPARK );
        }
        
+       /** Pattern 35 (nary cbind) with algebraic simplification rewrites, CP. */
+       @Test
+       public void testCodegenRowAggRewrite35CP() {
+               testCodegenIntegration(TEST_NAME35, true, ExecType.CP);
+       }
+       
+       /** Pattern 35 (nary cbind) without algebraic simplification rewrites, CP. */
+       @Test
+       public void testCodegenRowAgg35CP() {
+               testCodegenIntegration(TEST_NAME35, false, ExecType.CP);
+       }
+       
+       /** Pattern 35 (nary cbind) without algebraic simplification rewrites, Spark. */
+       @Test
+       public void testCodegenRowAgg35SP() {
+               testCodegenIntegration(TEST_NAME35, false, ExecType.SPARK);
+       }
+       
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
-       {       
+       {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                RUNTIME_PLATFORM platformOld = rtplatform;
                switch( instType ) {
@@ -620,7 +637,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                        programArgs = new String[]{"-explain", 
"recompile_runtime", "-stats", "-args", output("S") };
                        
                        fullRScriptName = HOME + testname + ".R";
-                       rCmd = getRCmd(inputDir(), expectedDir());              
        
+                       rCmd = getRCmd(inputDir(), expectedDir());
 
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
                        
@@ -632,7 +649,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("S");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                        
Assert.assertTrue(heavyHittersContainsSubString("spoofRA") 
-                                       || 
heavyHittersContainsSubString("sp_spoofRA"));
+                               || heavyHittersContainsSubString("sp_spoofRA"));
                        
                        //ensure full aggregates for certain patterns
                        if( testname.equals(TEST_NAME15) )
@@ -647,6 +664,9 @@ public class RowAggTmplTest extends AutomatedTestBase
                                        && 
!heavyHittersContainsSubString(RightIndex.OPCODE));
                        if( testname.equals(TEST_NAME31) )
                                
Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2));
+                       if( testname.equals(TEST_NAME35) )
+                               
Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2)
+                                       && 
!heavyHittersContainsSubString("cbind"));
                }
                finally {
                        rtplatform = platformOld;
@@ -655,7 +675,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                        OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
                        OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
                }
-       }       
+       }
 
        /**
         * Override default configuration with custom test configuration to 
ensure

http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/test/scripts/functions/codegen/rowAggPattern35.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern35.R 
b/src/test/scripts/functions/codegen/rowAggPattern35.R
new file mode 100644
index 0000000..d9bf8e9
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern35.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+# 300x20 input filled row-wise, plus two constant side inputs
+X = matrix(seq(1,6000)/6000, 300, 20, byrow=TRUE);
+Y = matrix(2, 300, 2);
+Z = matrix(3, 300, 3);
+
+# divide each row of X by its row sum (length-nrow vector recycles down
+# the columns), then append Y and Z in a single cbind
+R = cbind(X / rowSums(X), Y, Z);
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/b02599c6/src/test/scripts/functions/codegen/rowAggPattern35.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern35.dml 
b/src/test/scripts/functions/codegen/rowAggPattern35.dml
new file mode 100644
index 0000000..ee6d459
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern35.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# inputs: X (300x20) with values seq(1,6000)/6000, constant Y (300x2) and Z (300x3)
+X = matrix(seq(1,6000)/6000, 300, 20);
+Y = matrix(2, 300, 2);
+Z = matrix(3, 300, 3);
+# empty loop introduces a statement block boundary; presumably this keeps the
+# input generation out of the DAG of the cbind below -- TODO confirm
+while(FALSE){}
+
+# nary cbind of row-normalized X with Y and Z; RowAggTmplTest expects this to
+# compile into a single spoofRA operator without a separate cbind
+R = cbind(X/rowSums(X), Y, Z);
+
+write(R, $1)

Reply via email to