Repository: systemml Updated Branches: refs/heads/master 573943e0e -> c89d3be80
[SYSTEMML-2454] Fix codegen binary outer operation handling. So far, we generated invalid codegen plans for binary outer vector operations, leading to incorrect results. This patch effectively disables such outer vector operations (which anyway have dedicated physical operators that change their asymptotic behavior) in all codegen templates. Furthermore, this patch includes related tests. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c89d3be8 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c89d3be8 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c89d3be8 Branch: refs/heads/master Commit: c89d3be80c13d47c2545840ce5b33e7debec60a5 Parents: 573943e Author: Matthias Boehm <[email protected]> Authored: Wed Jul 18 18:23:51 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jul 18 18:23:51 2018 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 11 +++---- .../java/org/apache/sysml/hops/BinaryOp.java | 2 +- .../hops/codegen/template/TemplateCell.java | 15 ++++------ .../hops/codegen/template/TemplateUtils.java | 2 +- .../ipa/IPAPassRemoveConstantBinaryOps.java | 4 +-- .../RewriteAlgebraicSimplificationDynamic.java | 2 +- .../RewriteAlgebraicSimplificationStatic.java | 4 +-- .../functions/codegen/CellwiseTmplTest.java | 19 ++++++++++-- .../scripts/functions/codegen/cellwisetmpl27.R | 31 ++++++++++++++++++++ .../functions/codegen/cellwisetmpl27.dml | 24 +++++++++++++++ 10 files changed, 89 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index 4e6cf95..47943d0 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -555,7 +555,7 @@ public class AggUnaryOp extends MultiThreadedHop boolean ret = false; Hop input = getInput().get(0); - if( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() ) + if( input instanceof BinaryOp && ((BinaryOp)input).isOuter() ) { //for special cases, we need to hold the broadcast twice in order to allow for //an efficient binary search over a plain java array @@ -592,7 +592,7 @@ public class AggUnaryOp extends MultiThreadedHop boolean ret = false; Hop input = getInput().get(0); - if( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() ) + if( input instanceof BinaryOp && ((BinaryOp)input).isOuter() ) { //note: both cases (partitioned matrix, and sorted double array), require to //fit the broadcast twice into the local memory budget. Also, the memory @@ -634,16 +634,13 @@ public class AggUnaryOp extends MultiThreadedHop * * @return true if unary aggregate outer */ - private boolean isUnaryAggregateOuterCPRewriteApplicable() - { + private boolean isUnaryAggregateOuterCPRewriteApplicable() { boolean ret = false; Hop input = getInput().get(0); - - if(( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() ) + if(( input instanceof BinaryOp && ((BinaryOp)input).isOuter() ) && (_op == AggOp.MAXINDEX || _op == AggOp.MININDEX || _op == AggOp.SUM) && (isCompareOperator(((BinaryOp)input).getOp()))) ret = true; - return ret; } http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/BinaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java index 3624db8..80cfbcb 100644 --- a/src/main/java/org/apache/sysml/hops/BinaryOp.java +++ 
b/src/main/java/org/apache/sysml/hops/BinaryOp.java @@ -125,7 +125,7 @@ public class BinaryOp extends MultiThreadedHop outer = flag; } - public boolean isOuterVectorOperator(){ + public boolean isOuter(){ return outer; } http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index d4cb8fc..f17b35d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -336,20 +336,17 @@ public class TemplateCell extends TemplateBase boolean isBinaryMatrixScalar = false; boolean isBinaryMatrixVector = false; boolean isBinaryMatrixMatrix = false; - if( hop instanceof BinaryOp && hop.getDataType().isMatrix() ) { + if( hop instanceof BinaryOp && hop.getDataType().isMatrix() && !((BinaryOp)hop).isOuter() ) { Hop left = hop.getInput().get(0); Hop right = hop.getInput().get(1); - DataType ldt = left.getDataType(); - DataType rdt = right.getDataType(); - - isBinaryMatrixScalar = (ldt.isScalar() || rdt.isScalar()); + isBinaryMatrixScalar = (left.getDataType().isScalar() || right.getDataType().isScalar()); isBinaryMatrixVector = hop.dimsKnown() - && ((ldt.isMatrix() && TemplateUtils.isVectorOrScalar(right)) - || (rdt.isMatrix() && TemplateUtils.isVectorOrScalar(left)) ); + && ((left.getDataType().isMatrix() && TemplateUtils.isVectorOrScalar(right)) + || (right.getDataType().isMatrix() && TemplateUtils.isVectorOrScalar(left)) ); isBinaryMatrixMatrix = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right) - && ldt.isMatrix() && rdt.isMatrix(); + && left.getDataType().isMatrix() && right.getDataType().isMatrix(); } - + //prepare indicators for ternary operations boolean 
isTernaryVectorScalarVector = false; boolean isTernaryMatrixScalarMatrixDense = false; http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 3ca15d3..438eb56 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -142,7 +142,7 @@ public class TemplateUtils public static boolean isOperationSupported(Hop h) { if(h instanceof UnaryOp) return UnaryType.contains(((UnaryOp)h).getOp().name()); - else if(h instanceof BinaryOp) + else if(h instanceof BinaryOp && !((BinaryOp)h).isOuter()) return BinType.contains(((BinaryOp)h).getOp().name()); else if(h instanceof TernaryOp) return TernaryType.contains(((TernaryOp)h).getOp().name()); http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java index df44961..859e038 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java @@ -137,7 +137,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass return; if( hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MULT - && !((BinaryOp) hop).isOuterVectorOperator() + && !((BinaryOp) hop).isOuter() && hop.getInput().get(0).getDataType()==DataType.MATRIX && hop.getInput().get(1) instanceof DataOp && 
mOnes.containsKey(hop.getInput().get(1).getName()) ) @@ -153,6 +153,6 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass for( Hop c : hop.getInput() ) rRemoveConstantBinaryOp(c, mOnes); - hop.setVisited(); + hop.setVisited(); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 4f0ef51..36864aa 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2186,7 +2186,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) { //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY - if( hi instanceof BinaryOp && !((BinaryOp) hi).isOuterVectorOperator() + if( hi instanceof BinaryOp && !((BinaryOp) hi).isOuter() && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) ) { BinaryOp bop = (BinaryOp) hi; http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 4396c7b..62a5d4f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -1775,7 
+1775,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) //note: this rewrite supports both left/right sequence - if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && ((BinaryOp)hi).isOuterVectorOperator() ) + if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && ((BinaryOp)hi).isOuter() ) { if( ( HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a: outer(v, t(seq(1,m)), "==") && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) @@ -1833,7 +1833,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule else { OpOp2 optr = bop2.getComplementPPredOperation(); BinaryOp tmp = HopRewriteUtils.createBinary(bop2.getInput().get(0), - bop2.getInput().get(1), optr, bop2.isOuterVectorOperator()); + bop2.getInput().get(1), optr, bop2.isOuter()); HopRewriteUtils.replaceChildReference(parent, bop, tmp, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2); hi = tmp; http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index 8d962cc..a37ddcc 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -62,6 +62,7 @@ public class CellwiseTmplTest extends AutomatedTestBase private static final String TEST_NAME24 = TEST_NAME+24; //min(X, Y, Z, 3, 7) private static final String TEST_NAME25 = TEST_NAME+25; //bias_add private static final String TEST_NAME26 = TEST_NAME+26; //bias_mult + private static final String TEST_NAME27 = 
TEST_NAME+27; //outer < +7 negative private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/"; @@ -74,7 +75,7 @@ public class CellwiseTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for( int i=1; i<=26; i++ ) { + for( int i=1; i<=27; i++ ) { addTestConfiguration( TEST_NAME+i, new TestConfiguration( TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) ); } @@ -446,6 +447,20 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwiseRewrite26_sp() { testCodegenIntegration( TEST_NAME26, true, ExecType.SPARK ); } + + @Test + public void testCodegenCellwiseRewrite27() { + testCodegenIntegration( TEST_NAME27, true, ExecType.CP ); + } + + @Test + public void testCodegenCellwise27() { + testCodegenIntegration( TEST_NAME27, false, ExecType.CP ); + } + + public void testCodegenCellwiseRewrite27_sp() { + testCodegenIntegration( TEST_NAME27, true, ExecType.SPARK ); + } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { @@ -498,7 +513,7 @@ public class CellwiseTmplTest extends AutomatedTestBase } if( !(rewrites && (testname.equals(TEST_NAME2) - || testname.equals(TEST_NAME19))) ) //sigmoid + || testname.equals(TEST_NAME19))) && !testname.equals(TEST_NAME27) ) Assert.assertTrue(heavyHittersContainsSubString( "spoofCell", "sp_spoofCell", "spoofMA", "sp_spoofMA")); if( testname.equals(TEST_NAME7) ) //ensure matrix mult is fused http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/test/scripts/functions/codegen/cellwisetmpl27.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl27.R b/src/test/scripts/functions/codegen/cellwisetmpl27.R new file mode 100644 index 0000000..6f7e7c1 --- /dev/null +++ 
b/src/test/scripts/functions/codegen/cellwisetmpl27.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +A = seq(17,1,-1); +C = outer(A, t(A), "<")+7; +S = matrix(as.matrix(C), nrow=17, ncol=17, byrow=FALSE); + +writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); + \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/test/scripts/functions/codegen/cellwisetmpl27.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl27.dml b/src/test/scripts/functions/codegen/cellwisetmpl27.dml new file mode 100644 index 0000000..6c3c9b2 --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl27.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = seq(17,1,-1); +C = outer(A, t(A), "<")+7; +write(C, $1)
