Repository: incubator-systemml
Updated Branches:
  refs/heads/master dea42de1f -> cf4e5ab6e


[SYSTEMML-765] New rewrite 'pushdown sum binary mult', tests

Closes #173


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cf4e5ab6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cf4e5ab6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cf4e5ab6

Branch: refs/heads/master
Commit: cf4e5ab6e11273a8a468b669c22d76f55fc43f12
Parents: dea42de
Author: tgamal <[email protected]>
Authored: Wed Jun 8 19:22:01 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jun 8 19:22:01 2016 -0700

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   |  45 ++++++-
 .../misc/RewritePushdownSumBinaryMult.java      | 126 +++++++++++++++++++
 .../misc/RewritePushdownSumBinaryMult.R         |  23 ++++
 .../misc/RewritePushdownSumBinaryMult.dml       |  26 ++++
 .../misc/RewritePushdownSumBinaryMult2.R        |  24 ++++
 .../misc/RewritePushdownSumBinaryMult2.dml      |  26 ++++
 6 files changed, 266 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index fff9310..c36c01f 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -145,7 +145,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyBushyBinaryOperation(hop, hi, i);       
//e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
                        hi = simplifyUnaryAggReorgOperation(hop, hi, i);     
//e.g., sum(t(X)) -> sum(X)
                        hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
-                       hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
+                       hi = pushdownSumBinaryMult(hop, hi, i);                 
         //e.g., sum(lamda*X) -> lamda*sum(X)
+                       hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
                        hi = simplifyTransposedAppend(hop, hi, i);           
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
                        hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i);   
//e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X)
                        hi = simplifyTraceMatrixMult(hop, hi, i);            
//e.g., trace(X%*%Y)->sum(X*t(Y));  
@@ -161,8 +162,10 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = fuseLogNzBinaryOperation(hop, hi, i);           
//e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
                        hi = simplifyOuterSeqExpand(hop, hi, i);             
//e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, 
cast=false)
                        hi = simplifyTableSeqExpand(hop, hi, i);             
//e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, 
ignore=false, cast=true)
-                       //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
                        
+
+                       //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
+
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
                        if( !descendFirst )
                                rule_AlgebraicSimplification(hi, descendFirst);
@@ -928,7 +931,42 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                
                return hi;
        }
-       
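+       /**
+        * Rewrites a full sum over a scalar-matrix multiply: sum(lamda*X) -> lamda*sum(X).
+        * @param parent hop whose child reference at pos is replaced when the rewrite applies
+        * @param hi hop that is tested against the pattern
+        * @param pos position of hi in the inputs of parent
+        * @return the new multiply hop if the rewrite applied, otherwise hi unchanged
+        * @throws HopsException
+        */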
+       private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) throws 
HopsException {
+               //pattern:  sum(lamda*X) -> lamda*sum(X)
+               if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.RowCol
+                               && ((AggUnaryOp)hi).getOp()==Hop.AggOp.SUM
+                               && ((AggUnaryOp)hi).getInput().get(0) 
instanceof BinaryOp
+                               && 
((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.MULT
+                               && hi.getInput().get(0).getParent().size() == 1 
  // only one parent which is the sum
+                               && 
((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR && 
hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX)
+                                       
||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX && 
hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR)))
+               {
+                       Hop operand1 = hi.getInput().get(0).getInput().get(0); 
+                       Hop operand2 = hi.getInput().get(0).getInput().get(1);
+
+                       //check which operand is the Scalar and which is the 
matrix
+                       Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? 
operand1 : operand2; 
+                       Hop matrix = (operand1.getDataType()==DataType.MATRIX) 
? operand1 : operand2; 
+
+                       AggUnaryOp 
aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol);
+                       Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, 
OpOp2.MULT);
+                       
+                       HopRewriteUtils.removeChildReferenceByPos(parent, hi, 
pos);
+                       HopRewriteUtils.addChildReference(parent, bop, pos);
+                       
+                       LOG.debug("Applied pushdownSumBinaryMult.");
+                       return bop;
+               }
+               return hi;
+       }
        /**
         * 
         * @param parent
@@ -1870,5 +1908,4 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                
                return hi;
        }
-       
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java
new file mode 100644
index 0000000..9724d1d
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownSumBinaryMult.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Tests the algebraic simplification rewrite 'pushdown sum binary mult',
+ * i.e., sum(lamda*X) -> lamda*sum(X), with rewrites enabled and disabled.
+ */
+public class RewritePushdownSumBinaryMult extends AutomatedTestBase 
+{
+       
+       private static final String TEST_NAME1 = "RewritePushdownSumBinaryMult";
+       private static final String TEST_NAME2 = 
"RewritePushdownSumBinaryMult2";
+
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewritePushdownSumBinaryMult.class.getSimpleName() + "/";
+       
+       //private static final int rows = 1234;
+       //private static final int cols = 567;
+       private static final double eps = Math.pow(10, -10);
+       
+       @Override
+       public void setUp() 
+       {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+       }
+       
+       @Test
+       public void testPushdownSumBinaryMultNoRewrite() 
+       {
+               testRewritePushdownSumBinaryMult( TEST_NAME1, false );
+       }
+       
+       
+       @Test
+       public void testPushdownSumBinaryMultRewrite() 
+       {
+               testRewritePushdownSumBinaryMult( TEST_NAME1, true );
+       }
+       
+       
+       @Test
+       public void testPushdownSumBinaryMultNoRewrite2() 
+       {
+               testRewritePushdownSumBinaryMult( TEST_NAME2, false );
+       }
+       
+       @Test
+       public void testPushdownSumBinaryMultRewrite2() 
+       {
+               testRewritePushdownSumBinaryMult( TEST_NAME2, true );
+       }
+       
+       
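+       /**
+        * Runs the DML and R scripts of the given test and compares their scalar results.
+        * 
+        * @param testname name of the test script (without file extension)
+        * @param rewrites value set for OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION during the run
+        */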
+       private void testRewritePushdownSumBinaryMult( String testname, boolean 
rewrites )
+       {       
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               
+               try
+               {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-stats","-args", 
output("Scalar") };
+                       
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());              
        
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       runTest(true, false, null, -1); 
+                       runRScript(true); 
+                       
+                       //compare scalars 
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLScalarFromHDFS("Scalar");
+                       HashMap<CellIndex, Double> rfile  = 
readRScalarFromFS("Scalar");
+                       TestUtils.compareScalars(dmlfile.toString(), 
rfile.toString());
+                       System.out.println("Test case passed");
+                       
+               }
+               finally
+               {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+               
+       }       
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R 
b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R
new file mode 100644
index 0000000..48a000b
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.R
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X=matrix(1,10,10)
+lamda=sum(X)
+args<-commandArgs(TRUE)
+write(sum(lamda*X),paste(args[2],"Scalar",sep=""))

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml 
b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml
new file mode 100644
index 0000000..9850242
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X=matrix(1,10,10)
+if(1==1){}
+lamda=sum(X)
+y=sum(lamda*X)
+write(y, $1)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R 
b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R
new file mode 100644
index 0000000..09a0910
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.R
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X=matrix(1,10,10)
+lamda=sum(X)
+args<-commandArgs(TRUE)
+write(sum(X*lamda),paste(args[2],"Scalar",sep=""))

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cf4e5ab6/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml 
b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml
new file mode 100644
index 0000000..07e0e54
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewritePushdownSumBinaryMult2.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X=matrix(1,10,10)
+if(1==1){}
+lamda=sum(X)
+y=sum(X*lamda)
+write(y, $1)

Reply via email to