Repository: systemml
Updated Branches:
  refs/heads/master addd6e121 -> 0abeb60b3


[MINOR] Additional tests for row/col means/vars and matrix reshapes

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0abeb60b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0abeb60b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0abeb60b

Branch: refs/heads/master
Commit: 0abeb60b3c70925adb1b4e3ee8e4e4e42aa5f316
Parents: addd6e1
Author: Matthias Boehm <[email protected]>
Authored: Sun Apr 1 13:53:10 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Apr 1 13:53:10 2018 -0700

----------------------------------------------------------------------
 .../functions/misc/RewriteNNIssueTest.java      | 86 ++++++++++++++++++++
 .../scripts/functions/misc/RewriteNNIssue.R     | 49 +++++++++++
 .../scripts/functions/misc/RewriteNNIssue.dml   | 43 ++++++++++
 .../functions/misc/ZPackageSuite.java           |  1 +
 4 files changed, 179 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java
new file mode 100644
index 0000000..55c440b
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.junit.Test;
+
+import java.util.HashMap;
+
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+public class RewriteNNIssueTest extends AutomatedTestBase 
+{
+       private static final String TEST_NAME = "RewriteNNIssue";
+       
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteNNIssueTest.class.getSimpleName() + "/";
+       
+       private double eps = Math.pow(10, -10);
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) );
+       }
+       
+       @Test
+       public void testNNIssueRewrite() {
+               runNNIssueTest(true);
+       }
+       
+       @Test
+       public void testNNIssueNoRewrite() {
+               runNNIssueTest(false);
+       }
+       
+       private void runNNIssueTest(boolean rewrites)
+       {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{ "-stats","-args", 
output("R") };
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = getRCmd(expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+                       
+                       //run test
+                       runTest(true, false, null, -1); 
+                       runRScript(true); 
+                       
+                       //compare matrices 
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("R");
+                       HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test/scripts/functions/misc/RewriteNNIssue.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteNNIssue.R b/src/test/scripts/functions/misc/RewriteNNIssue.R
new file mode 100644
index 0000000..2f1f13a
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteNNIssue.R
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
args<-commandArgs(TRUE)
options(digits=22)
library(Matrix)
library(matrixStats)

# R reference for RewriteNNIssue.dml. Reshapes the column means/variances of X
# into C x (Hin*Win) matrices and derives, for each of the C rows, a mean and a
# variance plus their moving-average updates. Writes the C x 4 matrix
# [mean, var, ema_mean_upd, ema_var_upd] to paste(args[1], "R").

N = 2
C = 2
Hin = 3
Win = 4

# 2 x 24 input filled row-wise with the 48 values 1..20, 1..20, 1..8.
# c() yields exactly nrow*ncol values. The former cbind() recycled the shorter
# vector, and matrix() then truncated the data; both steps raised warnings.
X = matrix(c(seq(1,20),seq(1,20),seq(1,8)), nrow=2, ncol=24, byrow=TRUE)
# gamma, beta and epsilon are not used below; they mirror the DML script's inputs
gamma = matrix(c(1,2), byrow=TRUE, nrow=2, ncol=1)
beta = matrix(c(0,1), byrow=TRUE, nrow=2, ncol=1)
ema_mean = matrix(c(4,5), byrow=TRUE, nrow=2, ncol=1)
ema_var = matrix(c(2,3), byrow=TRUE, nrow=2, ncol=1)
mu = 0.95
epsilon = 1e-4

# colVars is the sample variance (denominator n-1); scaling by (N-1)/N turns it
# into the population variance over the N rows. byrow=TRUE mirrors DML's
# row-major reshape.
subgrp_means = matrix(colMeans(X), nrow=C, ncol=Hin*Win, byrow=TRUE)
subgrp_vars = matrix(colVars(X) * ((N-1)/N), nrow=C, ncol=Hin*Win, byrow=TRUE)
mean = rowMeans(subgrp_means)  # shape (C, 1)
# mean of within-group variances + population variance of the group means
var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
ema_mean_upd = mu*ema_mean + (1-mu)*mean
ema_var_upd = mu*ema_var + (1-mu)*var

R = cbind(mean, var, ema_mean_upd, ema_var_upd)

writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test/scripts/functions/misc/RewriteNNIssue.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteNNIssue.dml b/src/test/scripts/functions/misc/RewriteNNIssue.dml
new file mode 100644
index 0000000..56a99cf
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteNNIssue.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
# Regression script for a rewrite issue. Reshapes the column means/variances of X
# into C x (Hin*Win) matrices and derives, for each of the C rows, a mean and a
# variance plus their moving-average updates with weight mu. Writes the C x 4
# matrix [mean, var, ema_mean_upd, ema_var_upd] to $1.
# NOTE(review): the exact expression structure presumably is what exercises the
# rewrite under test, so keep it as-is. gamma, beta and epsilon are defined but
# not used below.
N = 2
C = 2
Hin = 3
Win = 4
# 48 values (1..20, 1..20, 1..8) stacked by rbind, reshaped row-wise into 2 x 24 (= N x C*Hin*Win)
X = matrix(rbind(seq(1,20),seq(1,20),seq(1,8)), rows=2, cols=24)
gamma = matrix("1 2", rows=2, cols=1)
beta = matrix("0 1", rows=2, cols=1)
ema_mean = matrix("4 5", rows=2, cols=1)
ema_var = matrix("2 3", rows=2, cols=1)
mu = 0.95
epsilon = 1e-4

# (N-1)/N rescales colVars to a population variance, assuming colVars uses the
# n-1 denominator (TODO confirm against the DML reference)
subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
mean = rowMeans(subgrp_means)  # shape (C, 1)
# mean of within-group variances + population variance of the group means
var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
ema_mean_upd = mu*ema_mean + (1-mu)*mean
ema_var_upd = mu*ema_var + (1-mu)*var

R = cbind(mean, var, ema_mean_upd, ema_var_upd)

write(R,$1)

http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 46385c2..b75b07a 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -65,6 +65,7 @@ import org.junit.runners.Suite;
        RewriteLoopVectorization.class,
        RewriteMatrixMultChainOptTest.class,
        RewriteMergeBlocksTest.class,
+       RewriteNNIssueTest.class,
        RewritePushdownSumBinaryMult.class,
        RewritePushdownSumOnBinaryTest.class,
        RewritePushdownUaggTest.class,

Reply via email to