Repository: systemml Updated Branches: refs/heads/master addd6e121 -> 0abeb60b3
[MINOR] Additional tests for row/col means/vars and matrix reshapes Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0abeb60b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0abeb60b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0abeb60b Branch: refs/heads/master Commit: 0abeb60b3c70925adb1b4e3ee8e4e4e42aa5f316 Parents: addd6e1 Author: Matthias Boehm <[email protected]> Authored: Sun Apr 1 13:53:10 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Apr 1 13:53:10 2018 -0700 ---------------------------------------------------------------------- .../functions/misc/RewriteNNIssueTest.java | 86 ++++++++++++++++++++ .../scripts/functions/misc/RewriteNNIssue.R | 49 +++++++++++ .../scripts/functions/misc/RewriteNNIssue.dml | 43 ++++++++++ .../functions/misc/ZPackageSuite.java | 1 + 4 files changed, 179 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java new file mode 100644 index 0000000..55c440b --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteNNIssueTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import org.junit.Test; + +import java.util.HashMap; + +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class RewriteNNIssueTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "RewriteNNIssue"; + + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteNNIssueTest.class.getSimpleName() + "/"; + + private double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) ); + } + + @Test + public void testNNIssueRewrite() { + runNNIssueTest(true); + } + + @Test + public void testNNIssueNoRewrite() { + runNNIssueTest(false); + } + + private void runNNIssueTest(boolean rewrites) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{ "-stats","-args", output("R") }; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + //run test + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test/scripts/functions/misc/RewriteNNIssue.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteNNIssue.R b/src/test/scripts/functions/misc/RewriteNNIssue.R new file mode 100644 index 0000000..2f1f13a --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteNNIssue.R @@ -0,0 +1,49 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library(Matrix) +library(matrixStats) + +N = 2 +C = 2 +Hin = 3 +Win = 4 + +X = matrix(cbind(seq(1,20),seq(1,20),seq(1,8)), nrow=2, ncol=24, byrow=TRUE) +gamma = matrix(c(1,2), byrow=TRUE, nrow=2, ncol=1) +beta = matrix(c(0,1), byrow=TRUE, nrow=2, ncol=1) +ema_mean = matrix(c(4,5), byrow=TRUE, nrow=2, ncol=1) +ema_var = matrix(c(2,3), byrow=TRUE, nrow=2, ncol=1) +mu = 0.95 +epsilon = 1e-4 + +subgrp_means = matrix(colMeans(X), nrow=C, ncol=Hin*Win, byrow=TRUE) +subgrp_vars = matrix(colVars(X) * ((N-1)/N), nrow=C, ncol=Hin*Win, byrow=TRUE) +mean = rowMeans(subgrp_means) # shape (C, 1) +var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) +ema_mean_upd = mu*ema_mean + (1-mu)*mean +ema_var_upd = mu*ema_var + (1-mu)*var + +R = cbind(mean, var, ema_mean_upd, ema_var_upd) + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test/scripts/functions/misc/RewriteNNIssue.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteNNIssue.dml b/src/test/scripts/functions/misc/RewriteNNIssue.dml new file mode 100644 index 0000000..56a99cf --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteNNIssue.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +N = 2 +C = 2 +Hin = 3 +Win = 4 +X = matrix(rbind(seq(1,20),seq(1,20),seq(1,8)), rows=2, cols=24) +gamma = matrix("1 2", rows=2, cols=1) +beta = matrix("0 1", rows=2, cols=1) +ema_mean = matrix("4 5", rows=2, cols=1) +ema_var = matrix("2 3", rows=2, cols=1) +mu = 0.95 +epsilon = 1e-4 + +subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win) +subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) +mean = rowMeans(subgrp_means) # shape (C, 1) +var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) +ema_mean_upd = mu*ema_mean + (1-mu)*mean +ema_var_upd = mu*ema_var + (1-mu)*var + +R = cbind(mean, var, ema_mean_upd, ema_var_upd) + +write(R,$1) http://git-wip-us.apache.org/repos/asf/systemml/blob/0abeb60b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index 46385c2..b75b07a 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -65,6 +65,7 @@ import org.junit.runners.Suite; RewriteLoopVectorization.class, RewriteMatrixMultChainOptTest.class, RewriteMergeBlocksTest.class, + RewriteNNIssueTest.class, RewritePushdownSumBinaryMult.class, RewritePushdownSumOnBinaryTest.class, RewritePushdownUaggTest.class,
