Repository: incubator-systemml
Updated Branches:
  refs/heads/master 686363208 -> 19e21744c


[SYSTEMML-1590] Fix codegen handling of unsupported row aggregates

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/19e21744
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/19e21744
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/19e21744

Branch: refs/heads/master
Commit: 19e21744c86adbedf6098906808c2c6327659cfe
Parents: 6863632
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sun May 7 20:49:38 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sun May 7 20:49:46 2017 -0700

----------------------------------------------------------------------
 .../hops/codegen/template/TemplateRow.java      |  6 ++--
 .../functions/codegen/RowAggTmplTest.java       | 18 ++++++++++-
 .../scripts/functions/codegen/rowAggPattern16.R | 33 ++++++++++++++++++++
 .../functions/codegen/rowAggPattern16.dml       | 27 ++++++++++++++++
 4 files changed, 81 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 3af8be4..3f947c8 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -78,7 +78,8 @@ public class TemplateRow extends TemplateBase
                        || (hop instanceof AggBinaryOp && hop.getDim2()==1
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol 
-                               && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1);
+                               && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1
+                               && HopRewriteUtils.isAggUnaryOp(hop, 
SUPPORTED_ROW_AGG));
        }
 
        @Override
@@ -89,7 +90,8 @@ public class TemplateRow extends TemplateBase
                                        || 
HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) 
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
                                        && TemplateCell.isValidOperation(hop))  
        
-                       || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol)
+                       || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
+                               && HopRewriteUtils.isAggUnaryOp(hop, 
SUPPORTED_ROW_AGG))
                        || (hop instanceof AggBinaryOp && hop.getDim1()>1 
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 4037edb..b7f82a7 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -51,6 +51,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME13 = TEST_NAME+"13"; 
//rowSums(X)+rowSums(Y)
        private static final String TEST_NAME14 = TEST_NAME+"14"; 
//colSums(max(floor(round(abs(min(sign(X+Y),1)))),7))
        private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn 
- softmax backward (partially)
+       private static final String TEST_NAME16 = TEST_NAME+"16"; 
//Y=X-rowIndexMax(X); R=Y/rowSums(Y)
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
@@ -62,7 +63,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=15; i++)
+               for(int i=1; i<=16; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
@@ -291,6 +292,21 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME15, false, ExecType.SPARK );
        }
        
+       @Test   
+       public void testCodegenRowAggRewrite16CP() {
+               testCodegenIntegration( TEST_NAME16, true, ExecType.CP );
+       }
+       
+       @Test
+       public void testCodegenRowAgg16CP() {
+               testCodegenIntegration( TEST_NAME16, false, ExecType.CP );
+       }
+       
+       @Test
+       public void testCodegenRowAgg16SP() {
+               testCodegenIntegration( TEST_NAME16, false, ExecType.SPARK );
+       }
+       
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/scripts/functions/codegen/rowAggPattern16.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern16.R b/src/test/scripts/functions/codegen/rowAggPattern16.R
new file mode 100644
index 0000000..a4e9184
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern16.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+
+X = matrix(seq(1,1500), 150, 10, byrow=TRUE);
+
+Y1 = X - max.col(X, ties.method="last") 
+R = Y1 / rowSums(Y1)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/scripts/functions/codegen/rowAggPattern16.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern16.dml b/src/test/scripts/functions/codegen/rowAggPattern16.dml
new file mode 100644
index 0000000..e0558f6
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern16.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1,1500), rows=150, cols=10);
+
+Y1 = X - rowIndexMax(X) 
+R = Y1 / rowSums(Y1)
+
+write(R, $1)

Reply via email to