Repository: systemml
Updated Branches:
  refs/heads/master ca04d7cdd -> 9d8fc723c

[SYSTEMML-1679] Add a new threshold utility function

This function accepts a matrix X and a threshold parameter thresh to get
an indicator matrix with values in {0, 1} depending on whether or not the
values in X are above thresh. It can be used, for example, to determine
the predicted class in a binary classification problem given the output
of a sigmoid layer.

Closes #548.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9d8fc723
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9d8fc723
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9d8fc723

Branch: refs/heads/master
Commit: 9d8fc723cdaad5d47692ba0b04e566b2a7d9b1bc
Parents: ca04d7c
Author: Fei Hu <hufe...@gmail.com>
Authored: Fri Jun 16 16:16:34 2017 -0700
Committer: Mike Dusenberry <mwdus...@us.ibm.com>
Committed: Fri Jun 16 16:16:34 2017 -0700

----------------------------------------------------------------------
 scripts/nn/test/README.md     |  4 ++--
 scripts/nn/test/run_tests.dml |  1 +
 scripts/nn/test/test.dml      | 22 ++++++++++++++++++++++
 scripts/nn/util.dml           | 16 ++++++++++++++++
 4 files changed, 41 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/README.md
----------------------------------------------------------------------
diff --git a/scripts/nn/test/README.md b/scripts/nn/test/README.md
index b714d50..0143752 100644
--- a/scripts/nn/test/README.md
+++ b/scripts/nn/test/README.md
@@ -26,7 +26,7 @@ limitations under the License.
 #### All layers are tested for correct derivatives ("gradient-checking"), and many layers also have correctness tests against simpler reference implementations.
 * `grad_check.dml` - Contains gradient-checks for all layers as individual DML functions.
 * `test.dml` - Contains correctness tests for several of the more complicated layers by checking against simple reference implementations, such as `conv_simple.dml`. All tests are formulated as individual DML functions.
-* `tests.dml` - A DML script that runs all of the tests in `grad_check.dml` and `test.dml`.
+* `run_tests.dml` - A DML script that runs all of the tests in `grad_check.dml` and `test.dml`.
 
 ## Execution
-* `spark-submit SystemML.jar -f nn/test/tests.dml` from the base of the project.
+* `spark-submit SystemML.jar -f nn/test/run_tests.dml` from the base of the project.


http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index b48606c..c9b1b3e 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -92,6 +92,7 @@ test::im2col()
 test::max_pool2d()
 test::padding()
 test::tanh()
+test::threshold()
 
 print("---")
 print("Other tests complete -- look for any ERRORs or WARNINGs.")


http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 52fb063..cfb8c79 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -605,3 +605,25 @@ tanh = function() {
   }
 }
 
+threshold = function() {
+  /*
+   * Test for threshold function.
+   */
+  print("Testing the threshold function.")
+
+  # Generate data
+  X = matrix("0.31 0.24 0.87
+              0.45 0.66 0.65
+              0.24 0.91 0.13", rows=3, cols=3)
+  thresh = 0.5
+  target_matrix = matrix("0.0 0.0 1.0
+                          0.0 1.0 1.0
+                          0.0 1.0 0.0", rows=3, cols=3)
+
+  # Get the indicator matrix
+  indicator_matrix = util::threshold(X, thresh)
+
+  # Equivalency check
+  out = test_util::check_all_equal(indicator_matrix, target_matrix)
+}
+


http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 3a73f08..c4da16a 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -200,3 +200,19 @@ unpad_image = function(matrix[double] img_padded, int Hin, int Win, int padh, in
   }
 }
 
+threshold = function(matrix[double] X, double thresh)
+    return (matrix[double] out) {
+  /*
+   * Computes an indicator matrix with values in {0, 1} depending on
+   * whether or not the values in X are above the input threshold
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (any, any).
+   *  - thresh: Input threshold.
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as X.
+   */
+  out = X > thresh
+}
+
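
The new util::threshold body is a single elementwise comparison, out = X > thresh, so the comparison is strict: an entry exactly equal to thresh maps to 0, not 1. The sketch below is not part of the commit. It is a minimal Python reference model of those semantics, assuming a list-of-lists of floats stands in for a DML matrix[double]. The helper names (threshold, sigmoid) and the logits are illustrative only. It reuses the data from the new test in test.dml and then shows the sigmoid-to-class use case described in the commit message.

```python
import math
from typing import List

# Assumption: a list of rows of floats stands in for a DML matrix[double].
Matrix = List[List[float]]


def threshold(X: Matrix, thresh: float) -> Matrix:
    """Model of DML `out = X > thresh`: elementwise, strict comparison.

    Returns 1.0 where an entry is strictly greater than thresh and 0.0
    otherwise (ties included), with the same shape as X.
    """
    return [[1.0 if x > thresh else 0.0 for x in row] for row in X]


def sigmoid(z: float) -> float:
    """Logistic function, standing in for a sigmoid output layer."""
    return 1.0 / (1.0 + math.exp(-z))


# Same inputs and expected indicator matrix as the new DML test.
X = [[0.31, 0.24, 0.87],
     [0.45, 0.66, 0.65],
     [0.24, 0.91, 0.13]]
target = [[0.0, 0.0, 1.0],
          [0.0, 1.0, 1.0],
          [0.0, 1.0, 0.0]]
assert threshold(X, 0.5) == target

# Binary classification: hypothetical logits -> probabilities -> class labels.
logits = [[-1.2], [0.0], [2.3]]
probs = [[sigmoid(z) for z in row] for row in logits]
# sigmoid(0.0) is exactly 0.5, so the strict '>' maps that tie to class 0.
labels = threshold(probs, 0.5)
print(labels)  # [[0.0], [0.0], [1.0]]
```

Whether ties should map to 0 or 1 is a design choice. The strict > in the commit matches the docstring's wording ("above the input threshold"). A caller who wants ties labeled positive would use >= instead.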