The TernaryAggregate rewrite now also applies to an element-wise power with a literal exponent of 3 (X^3 is handled like X*X*X).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f005d949 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f005d949 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f005d949 Branch: refs/heads/master Commit: f005d94997d9c17ad8e90b4d2bd340f81b9a752d Parents: 8b832f6 Author: Dylan Hutchison <[email protected]> Authored: Fri Jun 9 22:06:10 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Sun Jun 18 17:43:24 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 67 ++++++++++++-------- .../functions/misc/RewriteEMultChainTest.java | 7 +- .../functions/misc/RewriteEMultChainOp.R | 33 ---------- .../functions/misc/RewriteEMultChainOp.dml | 28 -------- .../functions/misc/RewriteEMultChainOpXYX.R | 33 ++++++++++ .../functions/misc/RewriteEMultChainOpXYX.dml | 28 ++++++++ 6 files changed, 106 insertions(+), 90 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index 4573b66..300a20c 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -490,29 +490,35 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop (_direction == Direction.RowCol || _direction == Direction.Col) ) { Hop input1 = getInput().get(0); - if( input1.getParent().size() == 1 && //sum single consumer - input1 instanceof BinaryOp && ((BinaryOp)input1).getOp()==OpOp2.MULT - // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it. 
- && input1.optFindExecType() != ExecType.MR) - { - Hop input11 = input1.getInput().get(0); - Hop input12 = input1.getInput().get(1); - - if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) { - //ternary, arbitrary matrices but no mv/outer operations. - ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) - && HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1) - && HopRewriteUtils.isEqualSize(input12, input1); - } - else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) { - //ternary, arbitrary matrices but no mv/outer operations. - ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) - && HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1) - && HopRewriteUtils.isEqualSize(input11, input1); + if (input1.getParent().size() == 1 + && input1 instanceof BinaryOp) { //sum single consumer + BinaryOp binput1 = (BinaryOp)input1; + + if (binput1.getOp() == OpOp2.POW + && binput1.getInput().get(1) instanceof LiteralOp) { + LiteralOp lit = (LiteralOp)binput1.getInput().get(1); + ret = lit.getLongValue() == 3; } - else { - //binary, arbitrary matrices but no mv/outer operations. - ret = HopRewriteUtils.isEqualSize(input11, input12); + else if (binput1.getOp() == OpOp2.MULT + // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it. + && input1.optFindExecType() != ExecType.MR) { + Hop input11 = input1.getInput().get(0); + Hop input12 = input1.getInput().get(1); + + if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) { + //ternary, arbitrary matrices but no mv/outer operations. 
+ ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils + .isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils + .isEqualSize(input12, input1); + } else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) { + //ternary, arbitrary matrices but no mv/outer operations. + ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils + .isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils + .isEqualSize(input11, input1); + } else { + //binary, arbitrary matrices but no mv/outer operations. + ret = HopRewriteUtils.isEqualSize(input11, input12); + } } } } @@ -626,14 +632,25 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop private Lop constructLopsTernaryAggregateRewrite(ExecType et) throws HopsException, LopsException { - Hop input1 = getInput().get(0); + BinaryOp input1 = (BinaryOp)getInput().get(0); Hop input11 = input1.getInput().get(0); Hop input12 = input1.getInput().get(1); Lop in1 = null, in2 = null, in3 = null; boolean handled = false; - - if( input11 instanceof BinaryOp ) { + + if (input1.getOp() == OpOp2.POW) { + switch ((int)((LiteralOp)input12).getLongValue()) { + case 3: + in1 = input11.constructLops(); + in2 = in1; + in3 = in1; + break; + default: + throw new AssertionError("unreachable; only applies to power 3"); + } + handled = true; + } else if (input11 instanceof BinaryOp ) { BinaryOp b11 = (BinaryOp)input11; switch (b11.getOp()) { case MULT: // A*B*C case http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java index 18ed55d..85dbea4 100644 --- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java @@ -37,7 +37,7 @@ import org.junit.Test; */ public class RewriteEMultChainTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "RewriteEMultChainOp"; + private static final String TEST_NAME1 = "RewriteEMultChainOpXYX"; private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/"; @@ -94,8 +94,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{ "-explain", "hops", "-stats", - "-args", input("X"), input("Y"), output("R") }; + programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") }; fullRScriptName = HOME + testname + ".R"; rCmd = getRCmd(inputDir(), expectedDir()); @@ -104,7 +103,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); writeInputMatrixWithMTD("X", X, true); writeInputMatrixWithMTD("Y", Y, true); - + //execute tests runTest(true, false, null, -1); runRScript(true); http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R deleted file mode 100644 index 6d94cc8..0000000 --- a/src/test/scripts/functions/misc/RewriteEMultChainOp.R +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - - -args <- commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") - -X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) -Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep=""))) - -R = X * Y * X; - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml deleted file mode 100644 index 3992403..0000000 --- a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml +++ /dev/null @@ -1,28 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - - -X = read($1); -Y = read($2); - -R = X * Y * X; - -write(R, $3); \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R new file mode 100644 index 0000000..6d94cc8 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep=""))) + +R = X * Y * X; + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml new file mode 100644 index 0000000..3992403 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +X = read($1); +Y = read($2); + +R = X * Y * X; + +write(R, $3); \ No newline at end of file
