Repository: incubator-systemml
Updated Branches:
  refs/heads/master 8f7cf77be -> 2e48d951b


[SYSTEMML-1447] Extended code generator (replace in rowagg/cell tmpls)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2e48d951
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2e48d951
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2e48d951

Branch: refs/heads/master
Commit: 2e48d951b825fe4ef85dc13f6d69934b8cadfe46
Parents: 8f7cf77
Author: Matthias Boehm <[email protected]>
Authored: Fri Mar 31 17:17:55 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Mar 31 18:21:17 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/ParameterizedBuiltinOp.java      |  5 ++++
 .../sysml/hops/codegen/cplan/CNodeTernary.java  | 16 ++++++++--
 .../hops/codegen/template/TemplateCell.java     | 25 +++++++++++++---
 .../hops/codegen/template/TemplateRowAgg.java   | 18 +++++++++++-
 .../hops/codegen/template/TemplateUtils.java    |  3 ++
 .../functions/codegen/CellwiseTmplTest.java     | 22 ++++++++++++--
 .../scripts/functions/codegen/cellwisetmpl11.R  | 31 ++++++++++++++++++++
 .../functions/codegen/cellwisetmpl11.dml        | 27 +++++++++++++++++
 8 files changed, 138 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index fa51948..1d6828c 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -154,6 +154,11 @@ public class ParameterizedBuiltinOp extends Hop implements 
MultiThreadedHop
                        getInput().get(_paramIndexMap.get("target")) : null;
        }
        
+       public Hop getParameterHop(String name) {
+               return _paramIndexMap.containsKey(name) ?   
+                       getInput().get(_paramIndexMap.get(name)) : null;        
+       }
+       
        @Override
        public void setMaxNumThreads( int k ) {
                _maxNumThreads = k;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index eb26eff..a8bbcb2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -28,6 +28,7 @@ public class CNodeTernary extends CNode
 {
        public enum TernaryType {
                PLUS_MULT, MINUS_MULT,
+               REPLACE, REPLACE_NAN,
                LOOKUP_RC1;
                
                public static boolean contains(String value) {
@@ -40,10 +41,17 @@ public class CNodeTernary extends CNode
                public String getTemplate(boolean sparse) {
                        switch (this) {
                                case PLUS_MULT:
-                                       return "    double %TMP% = %IN1% + 
%IN2% * %IN3%;\n" ;
+                                       return "    double %TMP% = %IN1% + 
%IN2% * %IN3%;\n";
                                
                                case MINUS_MULT:
-                                       return "    double %TMP% = %IN1% - 
%IN2% * %IN3%;\n" ;
+                                       return "    double %TMP% = %IN1% - 
%IN2% * %IN3%;\n";
+                                       
+                               case REPLACE:
+                                       return "    double %TMP% = (%IN1% == 
%IN2% || (Double.isNaN(%IN1%) "
+                                                       + "&& 
Double.isNaN(%IN2%))) ? %IN3% : %IN1%;\n";
+                               
+                               case REPLACE_NAN:
+                                       return "    double %TMP% = 
Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
                                        
                                case LOOKUP_RC1:
                                        return "    double %TMP% = 
%IN1%[rowIndex*%IN2%+%IN3%-1];\n";   
@@ -101,6 +109,8 @@ public class CNodeTernary extends CNode
                switch(_type) {
                        case PLUS_MULT: return "t(+*)";
                        case MINUS_MULT: return "t(-*)";
+                       case REPLACE: 
+                       case REPLACE_NAN: return "t(rplc)";
                        case LOOKUP_RC1: return "u(ixrc1)";
                        default:
                                return super.toString();        
@@ -112,6 +122,8 @@ public class CNodeTernary extends CNode
                switch(_type) {
                        case PLUS_MULT: 
                        case MINUS_MULT:
+                       case REPLACE:
+                       case REPLACE_NAN:
                        case LOOKUP_RC1:
                                _rows = 0;
                                _cols = 0;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 87ec899..447f6d6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -34,8 +34,10 @@ import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
@@ -157,7 +159,7 @@ public class TemplateCell extends TemplateBase
                        else if( cdata1 instanceof CNodeData && 
hop.getInput().get(0).getDataType().isMatrix() )
                                cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_RC);
                        
-                       String primitiveOpName = 
((UnaryOp)hop).getOp().toString();
+                       String primitiveOpName = ((UnaryOp)hop).getOp().name();
                        out = new CNodeUnary(cdata1, 
UnaryType.valueOf(primitiveOpName));
                }
                else if(hop instanceof BinaryOp)
@@ -165,7 +167,7 @@ public class TemplateCell extends TemplateBase
                        BinaryOp bop = (BinaryOp) hop;
                        CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
                        CNode cdata2 = 
tmp.get(hop.getInput().get(1).getHopID());
-                       String primitiveOpName = bop.getOp().toString();
+                       String primitiveOpName = bop.getOp().name();
                        
                        //cdata1 is vector
                        if( TemplateUtils.isColVector(cdata1) )
@@ -207,7 +209,21 @@ public class TemplateCell extends TemplateBase
                        
                        //construct ternary cnode, primitive operation derived 
from OpOp3
                        out = new CNodeTernary(cdata1, cdata2, cdata3, 
-                                       
TernaryType.valueOf(top.getOp().toString()));
+                                       
TernaryType.valueOf(top.getOp().name()));
+               }
+               else if( hop instanceof ParameterizedBuiltinOp ) 
+               {
+                       CNode cdata1 = 
tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
+                       if( TemplateUtils.isColVector(cdata1) )
+                               cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_R);
+                       else if( cdata1 instanceof CNodeData && 
hop.getInput().get(0).getDataType().isMatrix() )
+                               cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_RC);
+                       
+                       CNode cdata2 = 
tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
+                       CNode cdata3 = 
tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
+                       TernaryType ttype = (cdata2.isLiteral() && 
cdata2.getVarname().equals("Double.NaN")) ? 
+                                       TernaryType.REPLACE_NAN : 
TernaryType.REPLACE;
+                       out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
                }
                else if( hop instanceof IndexingOp ) 
                {
@@ -285,7 +301,8 @@ public class TemplateCell extends TemplateBase
                //check supported unary, binary, ternary operations
                return hop.getDataType() == DataType.MATRIX && 
TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp 
                                || isBinaryMatrixScalar || isBinaryMatrixVector 
|| isBinaryMatrixMatrixDense 
-                               || isTernaryVectorScalarVector || 
isTernaryMatrixScalarMatrixDense);    
+                               || isTernaryVectorScalarVector || 
isTernaryMatrixScalarMatrixDense
+                               || (hop instanceof ParameterizedBuiltinOp && 
((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));   
        }
        
        /**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
index f8f1508..2883893 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
@@ -78,7 +79,8 @@ public class TemplateRowAgg extends TemplateBase
                return !isClosed() && 
                        (  (hop instanceof BinaryOp && 
(HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
                                        || 
HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) 
-                       || (hop instanceof UnaryOp && 
TemplateCell.isValidOperation(hop))               
+                       || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
+                                       && TemplateCell.isValidOperation(hop))  
        
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol)
                        || (hop instanceof AggBinaryOp && hop.getDim1()>1 
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
@@ -255,6 +257,20 @@ public class TemplateRowAgg extends TemplateBase
                        out = new CNodeTernary(cdata1, cdata2, cdata3, 
                                        
TernaryType.valueOf(top.getOp().toString()));
                }
+               else if( hop instanceof ParameterizedBuiltinOp ) 
+               {
+                       CNode cdata1 = 
tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
+                       if( TemplateUtils.isColVector(cdata1) )
+                               cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_R);
+                       else if( cdata1 instanceof CNodeData && 
hop.getInput().get(0).getDataType().isMatrix() )
+                               cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_RC);
+                       
+                       CNode cdata2 = 
tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
+                       CNode cdata3 = 
tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
+                       TernaryType ttype = (cdata2.isLiteral() && 
cdata2.getVarname().equals("Double.NaN")) ? 
+                                       TernaryType.REPLACE_NAN : 
TernaryType.REPLACE;
+                       out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
+               }
                else if( hop instanceof IndexingOp ) 
                {
                        CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 3f5fed9..b959638 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -30,6 +30,7 @@ import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
@@ -105,6 +106,8 @@ public class TemplateUtils
                        return BinType.contains(((BinaryOp)h).getOp().name());
                else if(h instanceof TernaryOp)
                        return 
TernaryType.contains(((TernaryOp)h).getOp().name());
+               else if(h instanceof ParameterizedBuiltinOp) 
+                       return 
TernaryType.contains(((ParameterizedBuiltinOp)h).getOp().name());
                return false;
        }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
index 066b761..10aa038 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
@@ -46,6 +46,8 @@ public class CellwiseTmplTest extends AutomatedTestBase
        private static final String TEST_NAME8 = TEST_NAME+8;
        private static final String TEST_NAME9 = TEST_NAME+9;   //sum((X + 7 * 
Y)^2)
        private static final String TEST_NAME10 = TEST_NAME+10; //min/max(X + 7 
* Y)
+       private static final String TEST_NAME11 = TEST_NAME+11; //replace((0 / 
(X - 500))+1, 0/0, 7);
+       
 
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
CellwiseTmplTest.class.getSimpleName() + "/";
@@ -58,7 +60,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for( int i=1; i<=10; i++ ) {
+               for( int i=1; i<=11; i++ ) {
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(
                                        TEST_CLASS_DIR, TEST_NAME+i, new 
String[] {String.valueOf(i)}) );
                }
@@ -114,6 +116,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
        public void testCodegenCellwiseRewrite10() {
                testCodegenIntegration( TEST_NAME10, true, ExecType.CP  );
        }
+       
+       @Test
+       public void testCodegenCellwiseRewrite11() {
+               testCodegenIntegration( TEST_NAME11, true, ExecType.CP  );
+       }
 
        @Test
        public void testCodegenCellwise1() {
@@ -165,6 +172,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
        public void testCodegenCellwise10() {
                testCodegenIntegration( TEST_NAME10, false, ExecType.CP  );
        }
+       
+       @Test
+       public void testCodegenCellwise11() {
+               testCodegenIntegration( TEST_NAME11, false, ExecType.CP  );
+       }
 
        @Test
        public void testCodegenCellwiseRewrite1_sp() {
@@ -191,6 +203,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME10, true, ExecType.SPARK );
        }
        
+       @Test
+       public void testCodegenCellwiseRewrite11_sp() {
+               testCodegenIntegration( TEST_NAME11, true, ExecType.SPARK );
+       }
+       
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {       
                
@@ -247,7 +264,8 @@ public class CellwiseTmplTest extends AutomatedTestBase
                                
Assert.assertTrue(!heavyHittersContainsSubString("tsmm"));
                        else if( testname.equals(TEST_NAME10) ) //ensure 
min/max is fused
                                
Assert.assertTrue(!heavyHittersContainsSubString("uamin","uamax"));
-                               
+                       else if( testname.equals(TEST_NAME11) ) //ensure 
replace is fused
+                               
Assert.assertTrue(!heavyHittersContainsSubString("replace"));   
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
oldRewrites;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.R 
b/src/test/scripts/functions/codegen/cellwisetmpl11.R
new file mode 100644
index 0000000..33531ba
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl11.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(7, 1006), 500, 2, byrow=TRUE);
+
+Y = (0 / (X - 500))+1;
+R = replace(Y, is.nan(Y), 7);
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.dml 
b/src/test/scripts/functions/codegen/cellwisetmpl11.dml
new file mode 100644
index 0000000..c77da08
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl11.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(7, 1006), 500, 2);
+
+Y = (0 / (X - 500))+1;
+R = replace(target=Y, pattern=0/0, replacement=7);
+
+write(R, $1)

Reply via email to