Repository: systemml
Updated Branches:
  refs/heads/master 573943e0e -> c89d3be80


[SYSTEMML-2454] Fix codegen binary outer operation handling

So far, we generated invalid codegen plans for binary outer vector
operations, leading to incorrect results. This patch excludes such
outer vector operations (which already have dedicated physical
operators that change their asymptotic behavior) from all codegen
templates, renames BinaryOp.isOuterVectorOperator() to isOuter(),
and adds related tests.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c89d3be8
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c89d3be8
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c89d3be8

Branch: refs/heads/master
Commit: c89d3be80c13d47c2545840ce5b33e7debec60a5
Parents: 573943e
Author: Matthias Boehm <[email protected]>
Authored: Wed Jul 18 18:23:51 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jul 18 18:23:51 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  | 11 +++----
 .../java/org/apache/sysml/hops/BinaryOp.java    |  2 +-
 .../hops/codegen/template/TemplateCell.java     | 15 ++++------
 .../hops/codegen/template/TemplateUtils.java    |  2 +-
 .../ipa/IPAPassRemoveConstantBinaryOps.java     |  4 +--
 .../RewriteAlgebraicSimplificationDynamic.java  |  2 +-
 .../RewriteAlgebraicSimplificationStatic.java   |  4 +--
 .../functions/codegen/CellwiseTmplTest.java     | 19 ++++++++++--
 .../scripts/functions/codegen/cellwisetmpl27.R  | 31 ++++++++++++++++++++
 .../functions/codegen/cellwisetmpl27.dml        | 24 +++++++++++++++
 10 files changed, 89 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 4e6cf95..47943d0 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -555,7 +555,7 @@ public class AggUnaryOp extends MultiThreadedHop
                boolean ret = false;
                Hop input = getInput().get(0);
                
-               if( input instanceof BinaryOp && 
((BinaryOp)input).isOuterVectorOperator() )
+               if( input instanceof BinaryOp && ((BinaryOp)input).isOuter() )
                {
                        //for special cases, we need to hold the broadcast 
twice in order to allow for
                        //an efficient binary search over a plain java array
@@ -592,7 +592,7 @@ public class AggUnaryOp extends MultiThreadedHop
                boolean ret = false;
                Hop input = getInput().get(0);
                
-               if( input instanceof BinaryOp && 
((BinaryOp)input).isOuterVectorOperator() )
+               if( input instanceof BinaryOp && ((BinaryOp)input).isOuter() )
                {
                        //note: both cases (partitioned matrix, and sorted 
double array), require to
                        //fit the broadcast twice into the local memory budget. 
Also, the memory 
@@ -634,16 +634,13 @@ public class AggUnaryOp extends MultiThreadedHop
         *   
         * @return true if unary aggregate outer
         */
-       private boolean isUnaryAggregateOuterCPRewriteApplicable() 
-       {
+       private boolean isUnaryAggregateOuterCPRewriteApplicable() {
                boolean ret = false;
                Hop input = getInput().get(0);
-               
-               if(( input instanceof BinaryOp && 
((BinaryOp)input).isOuterVectorOperator() )
+               if(( input instanceof BinaryOp && ((BinaryOp)input).isOuter() )
                        && (_op == AggOp.MAXINDEX || _op == AggOp.MININDEX || 
_op == AggOp.SUM)
                        && (isCompareOperator(((BinaryOp)input).getOp())))
                        ret = true;
-
                return ret;
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java 
b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 3624db8..80cfbcb 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -125,7 +125,7 @@ public class BinaryOp extends MultiThreadedHop
                outer = flag;
        }
        
-       public boolean isOuterVectorOperator(){
+       public boolean isOuter(){
                return outer;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index d4cb8fc..f17b35d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -336,20 +336,17 @@ public class TemplateCell extends TemplateBase
                boolean isBinaryMatrixScalar = false;
                boolean isBinaryMatrixVector = false;
                boolean isBinaryMatrixMatrix = false;
-               if( hop instanceof BinaryOp && hop.getDataType().isMatrix() ) {
+               if( hop instanceof BinaryOp && hop.getDataType().isMatrix() && 
!((BinaryOp)hop).isOuter() ) {
                        Hop left = hop.getInput().get(0);
                        Hop right = hop.getInput().get(1);
-                       DataType ldt = left.getDataType();
-                       DataType rdt = right.getDataType();
-                       
-                       isBinaryMatrixScalar = (ldt.isScalar() || 
rdt.isScalar());      
+                       isBinaryMatrixScalar = (left.getDataType().isScalar() 
|| right.getDataType().isScalar());
                        isBinaryMatrixVector = hop.dimsKnown() 
-                               && ((ldt.isMatrix() && 
TemplateUtils.isVectorOrScalar(right)) 
-                               || (rdt.isMatrix() && 
TemplateUtils.isVectorOrScalar(left)) );
+                               && ((left.getDataType().isMatrix() && 
TemplateUtils.isVectorOrScalar(right)) 
+                               || (right.getDataType().isMatrix() && 
TemplateUtils.isVectorOrScalar(left)) );
                        isBinaryMatrixMatrix = hop.dimsKnown() && 
HopRewriteUtils.isEqualSize(left, right)
-                               && ldt.isMatrix() && rdt.isMatrix();
+                               && left.getDataType().isMatrix() && 
right.getDataType().isMatrix();
                }
-                               
+               
                //prepare indicators for ternary operations
                boolean isTernaryVectorScalarVector = false;
                boolean isTernaryMatrixScalarMatrixDense = false;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 3ca15d3..438eb56 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -142,7 +142,7 @@ public class TemplateUtils
        public static boolean isOperationSupported(Hop h) {
                if(h instanceof  UnaryOp)
                        return UnaryType.contains(((UnaryOp)h).getOp().name());
-               else if(h instanceof BinaryOp)
+               else if(h instanceof BinaryOp && !((BinaryOp)h).isOuter())
                        return BinType.contains(((BinaryOp)h).getOp().name());
                else if(h instanceof TernaryOp)
                        return 
TernaryType.contains(((TernaryOp)h).getOp().name());

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
index df44961..859e038 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -137,7 +137,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
                        return;
 
                if( hop instanceof BinaryOp && 
((BinaryOp)hop).getOp()==OpOp2.MULT
-                       && !((BinaryOp) hop).isOuterVectorOperator()
+                       && !((BinaryOp) hop).isOuter()
                        && hop.getInput().get(0).getDataType()==DataType.MATRIX
                        && hop.getInput().get(1) instanceof DataOp
                        && mOnes.containsKey(hop.getInput().get(1).getName()) )
@@ -153,6 +153,6 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
                for( Hop c : hop.getInput() )
                        rRemoveConstantBinaryOp(c, mOnes);
        
-               hop.setVisited();               
+               hop.setVisited();
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 4f0ef51..36864aa 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2186,7 +2186,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
        private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int 
pos) 
        {
                //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X 
- s*Y -> X -* sY
-               if( hi instanceof BinaryOp && !((BinaryOp) 
hi).isOuterVectorOperator()
+               if( hi instanceof BinaryOp && !((BinaryOp) hi).isOuter()
                        && (((BinaryOp)hi).getOp()==OpOp2.PLUS || 
((BinaryOp)hi).getOp()==OpOp2.MINUS) )
                {
                        BinaryOp bop = (BinaryOp) hi;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 4396c7b..62a5d4f 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1775,7 +1775,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                //pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, 
dir=row, ignore=true, cast=false)
                //note: this rewrite supports both left/right sequence 
                
-               if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && 
((BinaryOp)hi).isOuterVectorOperator() )
+               if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && 
((BinaryOp)hi).isOuter() )
                {
                        if(   ( 
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a: 
outer(v, t(seq(1,m)), "==")
                                    && 
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) 
@@ -1833,7 +1833,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        else {
                                OpOp2 optr = bop2.getComplementPPredOperation();
                                BinaryOp tmp = 
HopRewriteUtils.createBinary(bop2.getInput().get(0),
-                                       bop2.getInput().get(1), optr, 
bop2.isOuterVectorOperator());
+                                       bop2.getInput().get(1), optr, 
bop2.isOuter());
                                HopRewriteUtils.replaceChildReference(parent, 
bop, tmp, pos);
                                HopRewriteUtils.cleanupUnreferenced(bop, bop2);
                                hi = tmp;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
index 8d962cc..a37ddcc 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
@@ -62,6 +62,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
        private static final String TEST_NAME24 = TEST_NAME+24; //min(X, Y, Z, 
3, 7)
        private static final String TEST_NAME25 = TEST_NAME+25; //bias_add
        private static final String TEST_NAME26 = TEST_NAME+26; //bias_mult
+       private static final String TEST_NAME27 = TEST_NAME+27; //outer < +7 
negative
 
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
CellwiseTmplTest.class.getSimpleName() + "/";
@@ -74,7 +75,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for( int i=1; i<=26; i++ ) {
+               for( int i=1; i<=27; i++ ) {
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(
                                TEST_CLASS_DIR, TEST_NAME+i, new String[] 
{String.valueOf(i)}) );
                }
@@ -446,6 +447,22 @@ public class CellwiseTmplTest extends AutomatedTestBase
        public void testCodegenCellwiseRewrite26_sp() {
                testCodegenIntegration( TEST_NAME26, true, ExecType.SPARK );
        }
+       
+       //outer(A, t(A), "<")+7: outer ops are not fused, so no spoofCell is expected
+       @Test
+       public void testCodegenCellwiseRewrite27() {
+               testCodegenIntegration( TEST_NAME27, true, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenCellwise27() {
+               testCodegenIntegration( TEST_NAME27, false, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenCellwiseRewrite27_sp() {
+               testCodegenIntegration( TEST_NAME27, true, ExecType.SPARK );
+       }
 
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {
@@ -498,7 +513,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
                        }
                        
                        if( !(rewrites && (testname.equals(TEST_NAME2)
-                               || testname.equals(TEST_NAME19))) ) //sigmoid
+                               || testname.equals(TEST_NAME19))) && 
!testname.equals(TEST_NAME27) )
                                Assert.assertTrue(heavyHittersContainsSubString(
                                                "spoofCell", "sp_spoofCell", 
"spoofMA", "sp_spoofMA"));
                        if( testname.equals(TEST_NAME7) ) //ensure matrix mult 
is fused

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/test/scripts/functions/codegen/cellwisetmpl27.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl27.R 
b/src/test/scripts/functions/codegen/cellwisetmpl27.R
new file mode 100644
index 0000000..6f7e7c1
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl27.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE) # args[2] is used below as the output path prefix
+options(digits=22)
+library("Matrix")
+
+A = seq(17,1,-1); # descending vector 17, 16, ..., 1
+C = outer(A, t(A), "<")+7; # (A[i] < A[j]) + 7; t(A) is 1x17, so outer() returns a 3-d array
+S = matrix(as.matrix(C), nrow=17, ncol=17, byrow=FALSE); # flatten and reshape (column-major) into a 17x17 matrix
+
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); # write expected result in MatrixMarket format to <args[2]>S
+ 
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/c89d3be8/src/test/scripts/functions/codegen/cellwisetmpl27.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl27.dml 
b/src/test/scripts/functions/codegen/cellwisetmpl27.dml
new file mode 100644
index 0000000..6c3c9b2
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl27.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = seq(17,1,-1); # descending sequence 17, 16, ..., 1
+C = outer(A, t(A), "<")+7; # outer comparison (A[i] < A[j]) plus 7; outer ops are excluded from codegen fusion
+write(C, $1) # write result to the path given as the first script argument

Reply via email to