Repository: systemml
Updated Branches:
  refs/heads/master db13eac1a -> 50dafa038


[SYSTEMML-1678] Fix rewrite 'fuse axpy binary ops' for outer products

This patch fixes the dynamic simplification rewrite
fuseAxpyBinaryOperationChain so that it no longer triggers on outer
products of vectors, and adds related negative test cases.
 

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/50dafa03
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/50dafa03
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/50dafa03

Branch: refs/heads/master
Commit: 50dafa038ff3282f327260f2d413bdfd907bfe04
Parents: db13eac
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Mon Jun 26 22:33:23 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Mon Jun 26 22:33:23 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/ipa/FunctionCallSizeInfo.java    | 15 ++++++++++
 .../RewriteAlgebraicSimplificationDynamic.java  |  2 +-
 .../misc/RewriteFuseBinaryOpChainTest.java      | 24 +++++++++++++---
 .../misc/RewriteFuseBinaryOpChainTest4.R        | 30 ++++++++++++++++++++
 .../misc/RewriteFuseBinaryOpChainTest4.dml      | 29 +++++++++++++++++++
 5 files changed, 95 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java 
b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
index fb668b5..9f76e32 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
@@ -345,6 +345,21 @@ public class FunctionCallSizeInfo
                        sb.append("\n");
                }
                
+               sb.append("Valid #non-zeros for propagation: \n");
+               for( Entry<String, Set<Integer>> e : _fcandSafeNNZ.entrySet() ) 
{
+                       sb.append("--");
+                       sb.append(e.getKey());
+                       sb.append(": ");
+                       for( Integer pos : e.getValue() ) {
+                               sb.append(pos);
+                               sb.append(":");
+                               sb.append(_fgraph.getFunctionCalls(e.getKey())
+                                       .get(0).getInput().get(pos).getName());
+                               sb.append(" ");
+                       }
+                       sb.append("\n");
+               }
+               
                return sb.toString();
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 91c5972..9681e44 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2131,7 +2131,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
        private Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) 
        {
                //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X 
- s*Y -> X -* sY                
-               if( hi instanceof BinaryOp 
+               if( hi instanceof BinaryOp && !((BinaryOp) 
hi).isOuterVectorOperator()
                        && (((BinaryOp)hi).getOp()==OpOp2.PLUS || 
((BinaryOp)hi).getOp()==OpOp2.MINUS) )
                {
                        BinaryOp bop = (BinaryOp) hi;

http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
index 4c21587..f1d2a6a 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -32,7 +32,6 @@ import 
org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
-import org.apache.sysml.utils.Statistics;
 
 /**
  * Regression test for function recompile-once issue with literal replacement.
@@ -43,7 +42,8 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
        private static final String TEST_NAME1 = 
"RewriteFuseBinaryOpChainTest1"; //+* (X+s*Y)
        private static final String TEST_NAME2 = 
"RewriteFuseBinaryOpChainTest2"; //-* (X-s*Y) 
        private static final String TEST_NAME3 = 
"RewriteFuseBinaryOpChainTest3"; //+* (s*Y+X)
-
+       private static final String TEST_NAME4 = 
"RewriteFuseBinaryOpChainTest4"; //outer(X, s*Y, "+") not applied
+       
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/";
        
@@ -55,6 +55,7 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );          
  
        }
        
        @Test
@@ -147,6 +148,18 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                testFuseBinaryChain( TEST_NAME3, true, ExecType.MR );
        }
        
+       //negative tests
+       
+       @Test
+       public void testOuterBinaryPlusNoRewriteCP() {
+               testFuseBinaryChain( TEST_NAME4, false, ExecType.CP );
+       }
+       
+       @Test
+       public void testOuterBinaryPlusRewriteCP() {
+               testFuseBinaryChain( TEST_NAME4, true, ExecType.CP);
+       }
+       
        /**
         * 
         * @param testname
@@ -182,7 +195,7 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                        fullRScriptName = HOME + testname + ".R";
                        rCmd = getRCmd(inputDir(), expectedDir());              
        
 
-                       runTest(true, false, null, -1); 
+                       runTest(true, false, null, -1);
                        runRScript(true); 
                        
                        //compare matrices 
@@ -199,7 +212,10 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                                        prefix = Instruction.SP_INST_PREFIX;
                                
                                String opcode = 
(testname.equals(TEST_NAME1)||testname.equals(TEST_NAME3)) ? prefix+"+*" : 
prefix+"-*";
-                               Assert.assertTrue("Rewrite not 
applied.",Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+                               if( testname.equals(TEST_NAME4) )
+                                       Assert.assertFalse("Rewrite applied.", 
heavyHittersContainsSubString(opcode));
+                               else
+                                       Assert.assertTrue("Rewrite not 
applied.", heavyHittersContainsSubString(opcode));
                        }
                }
                finally

http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R 
b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R
new file mode 100644
index 0000000..7e9a392
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(1, 10, 1);
+Y = matrix(2, 1, 10);
+lambda = 7;
+
+S = outer(as.vector(X), as.vector(lambda*Y), "+");
+
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml 
b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml
new file mode 100644
index 0000000..0599f02
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, 10, 1);
+Y = matrix(2, 1, 10);
+lambda = 7;
+if(1==1){}
+
+S = outer(X, lambda*Y, "+");
+
+write(S,$1);

Reply via email to