Repository: systemml
Updated Branches:
  refs/heads/master 5f580f02e -> 4d370a8a6


[SYSTEMML-2374] New simplification rewrite 'fold nary min/max ops'

This patch adds a new dynamic rewrite for folding nested binary or nary
min/max operations into a single nary min/max operation. Due to limited
support for broadcasting, this is a dynamic rewrite that is only applied
if the dimensions of all involved matrix inputs are equal.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8d320791
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8d320791
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8d320791

Branch: refs/heads/master
Commit: 8d320791265321de38050e741308e3243ce89a7b
Parents: 5f580f0
Author: Matthias Boehm <[email protected]>
Authored: Thu Jun 7 22:23:05 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Jun 7 22:23:05 2018 -0700

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationDynamic.java  |  57 ++++++++-
 .../functions/misc/RewriteFoldMinMaxTest.java   | 118 +++++++++++++++++++
 .../scripts/functions/misc/RewriteFoldMax.dml   |  28 +++++
 .../scripts/functions/misc/RewriteFoldMin.dml   |  28 +++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 5 files changed, 231 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 81c20e0..062da2f 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -22,6 +22,7 @@ package org.apache.sysml.hops.rewrite;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
+import java.util.List;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -37,11 +38,13 @@ import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOp4;
+import org.apache.sysml.hops.Hop.OpOpN;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.NaryOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.QuaternaryOp;
@@ -182,7 +185,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                                hi = simplifyWeightedUnaryMM(hop, hi, i);       
  //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp)
                                hi = simplifyDotProductSum(hop, hi, i);         
  //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 
                                hi = fuseSumSquared(hop, hi, i);                
  //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
-                               hi = fuseAxpyBinaryOperationChain(hop, hi, i);  
  //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)      
+                               hi = fuseAxpyBinaryOperationChain(hop, hi, i);  
  //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)
                        }
                        hi = reorderMinusMatrixMult(hop, hi, i);          
//e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
                        hi = simplifySumMatrixMult(hop, hi, i);           
//e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
@@ -191,6 +194,8 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                        hi = simplifyNnzComputation(hop, hi, i);          
//e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
                        hi = simplifyNrowNcolComputation(hop, hi, i);     
//e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency
                        hi = simplifyTableSeqExpand(hop, hi, i);          
//e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, 
ignore=false, cast=true)
+                       if( OptimizerUtils.ALLOW_OPERATOR_FUSION )
+                               foldMultipleMinMaxOperations(hi);             
//e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y)
                        
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
                        if( !descendFirst )
@@ -2584,4 +2589,54 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
        
                return hi;
        }
+       
+       /**
+        * Folds nested min/max operations into a single nary min/max operation,
+        * e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y). In each iteration, the
+        * first input with the same min/max type as the current root is inspected;
+        * it is folded only if it has exactly one parent and all of its inputs are
+        * scalars or of equal size as the root. Folding repeats until this first
+        * candidate is absent or does not qualify. Nothing is changed in hadoop
+        * execution mode.
+        *
+        * @param hi high-level operator to inspect for nested min/max inputs
+        * @return the last created nary operator if at least one fold was applied,
+        *         otherwise the given hop
+        */
+       private static Hop foldMultipleMinMaxOperations(Hop hi)
+       {
+               //guard: only binary/nary min/max, and not in hadoop execution mode
+               boolean isMinMax = HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX)
+                       || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX);
+               if( !isMinMax || OptimizerUtils.isHadoopExecutionMode() )
+                       return hi;
+
+               //derive the equivalent binary and nary operator types of the root
+               final OpOp2 binType;
+               final OpOpN naryType;
+               if( hi instanceof BinaryOp ) {
+                       binType = ((BinaryOp)hi).getOp();
+                       naryType = OpOpN.valueOf(binType.name());
+               }
+               else {
+                       naryType = ((NaryOp)hi).getOp();
+                       binType = OpOp2.valueOf(naryType.name());
+               }
+
+               Hop root = hi;
+               while( true ) {
+                       //get first input that is a min/max of the same type
+                       Hop nested = null;
+                       for( Hop in : root.getInput() ) {
+                               if( HopRewriteUtils.isBinary(in, binType) || HopRewriteUtils.isNary(in, naryType) ) {
+                                       nested = in;
+                                       break;
+                               }
+                       }
+                       if( nested == null || nested.getParent().size() != 1 )
+                               break;
+
+                       //all inputs of the nested op must be scalars or match the root size
+                       boolean foldable = true;
+                       for( Hop c : nested.getInput() ) {
+                               if( c.getDataType() != DataType.SCALAR && !HopRewriteUtils.isEqualSize(root, c) ) {
+                                       foldable = false;
+                                       break;
+                               }
+                       }
+                       if( !foldable )
+                               break;
+
+                       //splice the nested inputs into the root inputs (in original order)
+                       List<Hop> merged = new ArrayList<>();
+                       for( Hop in : root.getInput() ) {
+                               if( in == nested )
+                                       merged.addAll(nested.getInput());
+                               else
+                                       merged.add(in);
+                       }
+                       Hop folded = HopRewriteUtils.createNary(naryType, merged.toArray(new Hop[0]));
+
+                       //clear dangling references of the replaced operators
+                       HopRewriteUtils.removeAllChildReferences(root);
+                       HopRewriteUtils.removeAllChildReferences(nested);
+
+                       //rewire all parents via a copy of the parent list, because
+                       //replacing the child reference may modify the original list
+                       List<Hop> parents = new ArrayList<>(root.getParent());
+                       for( Hop p : parents )
+                               HopRewriteUtils.replaceChildReference(p, root, folded);
+
+                       root = folded;
+                       LOG.debug("Applied foldMultipleMinMaxOperations (line "+root.getBeginLine()+").");
+               }
+
+               return root;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java
new file mode 100644
index 0000000..65c2a3e
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldMinMaxTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Tests the folding of nested min/max operations into a single nary
+ * min/max operation. Each script is run once with and once without
+ * algebraic simplification rewrites; the result must be identical, and
+ * with rewrites enabled the heavy hitters are checked for the expected
+ * nary instruction.
+ */
+public class RewriteFoldMinMaxTest extends AutomatedTestBase
+{
+       private static final String TEST_NAME1 = "RewriteFoldMin";
+       private static final String TEST_NAME2 = "RewriteFoldMax";
+
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFoldMinMaxTest.class.getSimpleName() + "/";
+
+       private static final int rows = 1932;
+       private static final int cols = 14;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testRewriteFoldMinNoRewrite() {
+               testRewriteFoldMinMax( TEST_NAME1, false, ExecType.CP );
+       }
+
+       @Test
+       public void testRewriteFoldMinRewrite() {
+               testRewriteFoldMinMax( TEST_NAME1, true, ExecType.CP );
+       }
+
+       @Test
+       public void testRewriteFoldMaxNoRewrite() {
+               testRewriteFoldMinMax( TEST_NAME2, false, ExecType.CP );
+       }
+
+       @Test
+       public void testRewriteFoldMaxRewrite() {
+               testRewriteFoldMinMax( TEST_NAME2, true, ExecType.CP );
+       }
+
+       /**
+        * Runs the given test script and validates its result and, if rewrites
+        * are enabled, the executed min/max instructions.
+        *
+        * @param testname name of the DML test script (without file extension)
+        * @param rewrites value used for OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION
+        * @param et execution type, mapped to the runtime platform of the test run
+        */
+       private void testRewriteFoldMinMax( String testname, boolean rewrites, ExecType et )
+       {
+               //remember global settings to restore them after the test
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( et ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+               }
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+
+               try {
+                       TestConfiguration config = getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-stats", "-args", String.valueOf(rows),
+                                       String.valueOf(cols), output("R") };
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+                       //run test
+                       runTest(true, false, null, -1);
+
+                       //compare result (expected sum is 5*rows*cols for both scripts)
+                       Double ret = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1));
+                       Assert.assertEquals("Wrong result", Double.valueOf(5*rows*cols), ret);
+
+                       //check for applied rewrites (separate assertions to pinpoint failures)
+                       if( rewrites ) {
+                               Assert.assertFalse("Heavy hitters unexpectedly contain 'min'",
+                                       heavyHittersContainsString("min"));
+                               Assert.assertFalse("Heavy hitters unexpectedly contain 'max'",
+                                       heavyHittersContainsString("max"));
+                               if( testname.equals(TEST_NAME1) )
+                                       Assert.assertTrue("Expected exactly one 'nmin' instruction",
+                                               Statistics.getCPHeavyHitterCount("nmin") == 1);
+                               if( testname.equals(TEST_NAME2) )
+                                       Assert.assertTrue("Expected exactly one 'nmax' instruction",
+                                               Statistics.getCPHeavyHitterCount("nmax") == 1);
+                       }
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       rtplatform = platformOld;
+               }
+       }
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test/scripts/functions/misc/RewriteFoldMax.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFoldMax.dml 
b/src/test/scripts/functions/misc/RewriteFoldMax.dml
new file mode 100644
index 0000000..c5117c8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFoldMax.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, $1, $2)                # matrix filled with 1; dimensions given by script args $1 and $2
+while(FALSE){}                       # NOTE(review): presumably a no-op used to cut the program into separate blocks -- confirm
+Y = max(X-7,max(max(X-5,-7),5))      # nested max; with X=1: max(-6, max(max(-4,-7), 5)) = 5 per cell
+while(FALSE){}
+R = as.matrix(sum(Y))                # sum over all cells of Y, wrapped as a matrix for writing
+
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test/scripts/functions/misc/RewriteFoldMin.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFoldMin.dml 
b/src/test/scripts/functions/misc/RewriteFoldMin.dml
new file mode 100644
index 0000000..7919d9b
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFoldMin.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, $1, $2)                # matrix filled with 1; dimensions given by script args $1 and $2
+while(FALSE){}                       # NOTE(review): presumably a no-op used to cut the program into separate blocks -- confirm
+Y = min(X+7,min(min(X+5,7),5))       # nested min; with X=1: min(8, min(min(6,7), 5)) = 5 per cell
+while(FALSE){}
+R = as.matrix(sum(Y))                # sum over all cells of Y, wrapped as a matrix for writing
+
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/systemml/blob/8d320791/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index e2c7bf1..75e9970 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -61,6 +61,7 @@ import org.junit.runners.Suite;
        RewriteCTableToRExpandTest.class,
        RewriteElementwiseMultChainOptimizationTest.class,
        RewriteEliminateAggregatesTest.class,
+       RewriteFoldMinMaxTest.class,
        RewriteFoldRCBindTest.class,
        RewriteFuseBinaryOpChainTest.class,
        RewriteFusedRandTest.class,

Reply via email to