TernaryAggregate rewrite now also applies to X^3 (element-wise power with literal exponent 3).

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f005d949
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f005d949
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f005d949

Branch: refs/heads/master
Commit: f005d94997d9c17ad8e90b4d2bd340f81b9a752d
Parents: 8b832f6
Author: Dylan Hutchison <[email protected]>
Authored: Fri Jun 9 22:06:10 2017 -0700
Committer: Dylan Hutchison <[email protected]>
Committed: Sun Jun 18 17:43:24 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  | 67 ++++++++++++--------
 .../functions/misc/RewriteEMultChainTest.java   |  7 +-
 .../functions/misc/RewriteEMultChainOp.R        | 33 ----------
 .../functions/misc/RewriteEMultChainOp.dml      | 28 --------
 .../functions/misc/RewriteEMultChainOpXYX.R     | 33 ++++++++++
 .../functions/misc/RewriteEMultChainOpXYX.dml   | 28 ++++++++
 6 files changed, 106 insertions(+), 90 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 4573b66..300a20c 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -490,29 +490,35 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                        (_direction == Direction.RowCol || _direction == 
Direction.Col)  ) 
                {
                        Hop input1 = getInput().get(0);
-                       if( input1.getParent().size() == 1 && //sum single 
consumer
-                               input1 instanceof BinaryOp && 
((BinaryOp)input1).getOp()==OpOp2.MULT
-                               // As unary agg instruction is not implemented 
in MR and since MR is in maintenance mode, postponed it.
-                               && input1.optFindExecType() != ExecType.MR) 
-                       {
-                               Hop input11 = input1.getInput().get(0);
-                               Hop input12 = input1.getInput().get(1);
-                               
-                               if( input11 instanceof BinaryOp && 
((BinaryOp)input11).getOp()==OpOp2.MULT ) {
-                                       //ternary, arbitrary matrices but no 
mv/outer operations.
-                                       ret = 
HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
-                                               && 
HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)       
-                                               && 
HopRewriteUtils.isEqualSize(input12, input1);
-                               }
-                               else if( input12 instanceof BinaryOp && 
((BinaryOp)input12).getOp()==OpOp2.MULT ) {
-                                       //ternary, arbitrary matrices but no 
mv/outer operations.
-                                       ret = 
HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
-                                                       && 
HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)       
-                                                       && 
HopRewriteUtils.isEqualSize(input11, input1);
+                       if (input1.getParent().size() == 1
+                                       && input1 instanceof BinaryOp) { //sum 
single consumer
+                               BinaryOp binput1 = (BinaryOp)input1;
+
+                               if (binput1.getOp() == OpOp2.POW
+                                               && binput1.getInput().get(1) 
instanceof LiteralOp) {
+                                       LiteralOp lit = 
(LiteralOp)binput1.getInput().get(1);
+                                       ret = lit.getLongValue() == 3;
                                }
-                               else {
-                                       //binary, arbitrary matrices but no 
mv/outer operations.
-                                       ret = 
HopRewriteUtils.isEqualSize(input11, input12);
+                               else if (binput1.getOp() == OpOp2.MULT
+                                               // As unary agg instruction is 
not implemented in MR and since MR is in maintenance mode, postponed it.
+                                               && input1.optFindExecType() != 
ExecType.MR) {
+                                       Hop input11 = input1.getInput().get(0);
+                                       Hop input12 = input1.getInput().get(1);
+
+                                       if (input11 instanceof BinaryOp && 
((BinaryOp) input11).getOp() == OpOp2.MULT) {
+                                               //ternary, arbitrary matrices 
but no mv/outer operations.
+                                               ret = 
HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && 
HopRewriteUtils
+                                                               
.isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils
+                                                               
.isEqualSize(input12, input1);
+                                       } else if (input12 instanceof BinaryOp 
&& ((BinaryOp) input12).getOp() == OpOp2.MULT) {
+                                               //ternary, arbitrary matrices 
but no mv/outer operations.
+                                               ret = 
HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && 
HopRewriteUtils
+                                                               
.isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils
+                                                               
.isEqualSize(input11, input1);
+                                       } else {
+                                               //binary, arbitrary matrices 
but no mv/outer operations.
+                                               ret = 
HopRewriteUtils.isEqualSize(input11, input12);
+                                       }
                                }
                        }
                }
@@ -626,14 +632,25 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
        private Lop constructLopsTernaryAggregateRewrite(ExecType et) 
                throws HopsException, LopsException
        {
-               Hop input1 = getInput().get(0);
+               BinaryOp input1 = (BinaryOp)getInput().get(0);
                Hop input11 = input1.getInput().get(0);
                Hop input12 = input1.getInput().get(1);
                
                Lop in1 = null, in2 = null, in3 = null;
                boolean handled = false;
-               
-               if( input11 instanceof BinaryOp ) {
+
+               if (input1.getOp() == OpOp2.POW) {
+                       switch ((int)((LiteralOp)input12).getLongValue()) {
+                       case 3:
+                               in1 = input11.constructLops();
+                               in2 = in1;
+                               in3 = in1;
+                               break;
+                       default:
+                               throw new AssertionError("unreachable; only 
applies to power 3");
+                       }
+                       handled = true;
+               } else if (input11 instanceof BinaryOp ) {
                        BinaryOp b11 = (BinaryOp)input11;
                        switch (b11.getOp()) {
                        case MULT: // A*B*C case

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index 18ed55d..85dbea4 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -37,7 +37,7 @@ import org.junit.Test;
  */
 public class RewriteEMultChainTest extends AutomatedTestBase
 {
-       private static final String TEST_NAME1 = "RewriteEMultChainOp";
+       private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteEMultChainTest.class.getSimpleName() + "/";
        
@@ -94,8 +94,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{ "-explain", "hops", 
"-stats", 
-                               "-args", input("X"), input("Y"), output("R") };
+                       programArgs = new String[] { "-explain", "hops", 
"-stats", "-args", input("X"), input("Y"), output("R") };
                        fullRScriptName = HOME + testname + ".R";
                        rCmd = getRCmd(inputDir(), expectedDir());              
        
 
@@ -104,7 +103,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
                        double[][] Y = getRandomMatrix(rows, cols, -1, 1, 
Ysparsity, 3);
                        writeInputMatrixWithMTD("X", X, true);
                        writeInputMatrixWithMTD("Y", Y, true);
-                       
+
                        //execute tests
                        runTest(true, false, null, -1); 
                        runRScript(true); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R 
b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
deleted file mode 100644
index 6d94cc8..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.R
+++ /dev/null
@@ -1,33 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-args <- commandArgs(TRUE)
-options(digits=22)
-library("Matrix")
-library("matrixStats")
-
-X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
-Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
-
-R = X * Y * X;
-
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml 
b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
deleted file mode 100644
index 3992403..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
+++ /dev/null
@@ -1,28 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-X = read($1);
-Y = read($2);
-
-R = X * Y * X;
-
-write(R, $3);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R 
b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml 
b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+
+R = X * Y * X;
+
+write(R, $3);
\ No newline at end of file

Reply via email to