[SYSTEMML-2082] Codegen support for ternary ifelse in cell/magg tmpls

This patch adds basic support for ternary ifelse operations in the codegen
cell and multi-aggregate (magg) templates, along with related tests.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/aa537dad
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/aa537dad
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/aa537dad

Branch: refs/heads/master
Commit: aa537dad43f2cf21badaedcb8629b27ad301032b
Parents: 5457066
Author: Matthias Boehm <[email protected]>
Authored: Tue Feb 6 20:06:28 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Tue Feb 6 20:06:28 2018 -0800

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeTernary.java  | 23 ++++++++------
 .../hops/codegen/template/TemplateCell.java     | 12 +++++---
 .../functions/codegen/CellwiseTmplTest.java     | 18 ++++++++++-
 .../scripts/functions/codegen/cellwisetmpl18.R  | 32 ++++++++++++++++++++
 .../functions/codegen/cellwisetmpl18.dml        | 30 ++++++++++++++++++
 5 files changed, 99 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index 155cc8b..dc8ff82 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -27,7 +27,7 @@ public class CNodeTernary extends CNode
 {
        public enum TernaryType {
                PLUS_MULT, MINUS_MULT,
-               REPLACE, REPLACE_NAN,
+               REPLACE, REPLACE_NAN, IFELSE,
                LOOKUP_RC1, LOOKUP_RVECT1;
                
                
@@ -52,7 +52,10 @@ public class CNodeTernary extends CNode
                                
                                case REPLACE_NAN:
                                        return "    double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
-                                       
+                               
+                               case IFELSE:
+                                       return "    double %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+                               
                                case LOOKUP_RC1:
                                        return sparse ?
                                                "    double %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
@@ -124,15 +127,14 @@ public class CNodeTernary extends CNode
        @Override
        public String toString() {
                switch(_type) {
-                       case PLUS_MULT: return "t(+*)";
-                       case MINUS_MULT: return "t(-*)";
-                       case REPLACE: 
-                       case REPLACE_NAN: return "t(rplc)";
-                       case LOOKUP_RC1: return "u(ixrc1)";
+                       case PLUS_MULT:     return "t(+*)";
+                       case MINUS_MULT:    return "t(-*)";
+                       case REPLACE:
+                       case REPLACE_NAN:   return "t(rplc)";
+                       case IFELSE:        return "t(ifelse)";
+                       case LOOKUP_RC1:    return "u(ixrc1)";
                        case LOOKUP_RVECT1: return "u(ixrv1)";
-                       
-                       default:
-                               return super.toString();        
+                       default:            return super.toString();
                }
        }
        
@@ -143,6 +145,7 @@ public class CNodeTernary extends CNode
                        case MINUS_MULT:
                        case REPLACE:
                        case REPLACE_NAN:
+                       case IFELSE:
                        case LOOKUP_RC1:
                                _rows = 0;
                                _cols = 0;

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 50b42ea..2b8db2a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -34,6 +34,7 @@ import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
@@ -168,7 +169,7 @@ public class TemplateCell extends TemplateBase
                                        && 
HopRewriteUtils.isMatrixMultiply(hop) && i==0 ) //skip transpose
                                rConstructCplan(c.getInput().get(0), memo, tmp, 
inHops, compileLiterals);
                        else {
-                               CNodeData cdata = 
TemplateUtils.createCNodeData(c, compileLiterals);    
+                               CNodeData cdata = 
TemplateUtils.createCNodeData(c, compileLiterals);
                                tmp.put(c.getHopID(), cdata);
                                inHops.add(c);
                        }
@@ -208,6 +209,7 @@ public class TemplateCell extends TemplateBase
                        
                        //add lookups if required
                        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
+                       cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
                        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
                        
                        //construct ternary cnode, primitive operation derived 
from OpOp3
@@ -299,11 +301,11 @@ public class TemplateCell extends TemplateBase
                //prepare indicators for ternary operations
                boolean isTernaryVectorScalarVector = false;
                boolean isTernaryMatrixScalarMatrixDense = false;
+               boolean isTernaryIfElse = (HopRewriteUtils.isTernary(hop, OpOp3.IFELSE) && hop.getDataType().isMatrix());
                if( hop instanceof TernaryOp && hop.getInput().size()==3 && 
hop.dimsKnown() 
-                       && HopRewriteUtils.checkInputDataTypes(hop, 
DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
+                       && HopRewriteUtils.checkInputDataTypes(hop, 
DataType.MATRIX, DataType.SCALAR, DataType.MATRIX) ) {
                        Hop left = hop.getInput().get(0);
                        Hop right = hop.getInput().get(2);
-                       
                        isTernaryVectorScalarVector = 
TemplateUtils.isVector(left) && TemplateUtils.isVector(right);
                        isTernaryMatrixScalarMatrixDense = 
HopRewriteUtils.isEqualSize(left, right) 
                                && !HopRewriteUtils.isSparse(left) && 
!HopRewriteUtils.isSparse(right);
@@ -312,8 +314,8 @@ public class TemplateCell extends TemplateBase
                //check supported unary, binary, ternary operations
                return hop.getDataType() == DataType.MATRIX && 
TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp 
                                || isBinaryMatrixScalar || isBinaryMatrixVector 
|| isBinaryMatrixMatrix
-                               || isTernaryVectorScalarVector || 
isTernaryMatrixScalarMatrixDense
-                               || (hop instanceof ParameterizedBuiltinOp && 
((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));   
+                               || isTernaryVectorScalarVector || 
isTernaryMatrixScalarMatrixDense || isTernaryIfElse
+                               || (hop instanceof ParameterizedBuiltinOp && 
((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));
        }
        
        protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, 
List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
index bd3b36a..2f44f61 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
@@ -53,6 +53,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
        private static final String TEST_NAME15 = TEST_NAME+15; //colMins(2*log(X))
        private static final String TEST_NAME16 = TEST_NAME+16; //colSums(2*log(X));
        private static final String TEST_NAME17 = TEST_NAME+17; //xor operation
+       private static final String TEST_NAME18 = TEST_NAME+18; //sum(ifelse(X,Y,Z))
        
        
        private static final String TEST_DIR = "functions/codegen/";
@@ -66,7 +67,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for( int i=1; i<=17; i++ ) {
+               for( int i=1; i<=18; i++ ) {
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(
                                        TEST_CLASS_DIR, TEST_NAME+i, new 
String[] {String.valueOf(i)}) );
                }
@@ -304,6 +305,21 @@ public class CellwiseTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME17, true, ExecType.SPARK );
        }
        
+       @Test
+       public void testCodegenCellwiseRewrite18() {
+               testCodegenIntegration( TEST_NAME18, true, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenCellwise18() {
+               testCodegenIntegration( TEST_NAME18, false, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenCellwiseRewrite18_sp() {
+               testCodegenIntegration( TEST_NAME18, true, ExecType.SPARK );
+       }
+       
        
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {                       

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.R b/src/test/scripts/functions/codegen/cellwisetmpl18.R
new file mode 100644
index 0000000..e6a275a
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl18.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(-1000, 198999), 1000, 200, byrow=TRUE);
+Y = matrix(seq(0, 199999), 1000, 200, byrow=TRUE);
+Z = matrix(seq(1000, 200999), 1000, 200, byrow=TRUE);
+
+R = as.matrix(sum(as.numeric(ifelse(X,Y,Z))));
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.dml b/src/test/scripts/functions/codegen/cellwisetmpl18.dml
new file mode 100644
index 0000000..c178dd3
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl18.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(-1000, 198999), 1000, 200);
+Y = matrix(seq(0, 199999), 1000, 200);
+Z = matrix(seq(1000, 200999), 1000, 200);
+
+while(FALSE){}
+
+R = as.matrix(sum(ifelse(X,Y,Z)));
+
+write(R, $1)

Reply via email to