Posted to commits@systemml.apache.org by du...@apache.org on 2017/06/16 23:17:56 UTC
systemml git commit: [SYSTEMML-1679] Add a new threshold utility function
Repository: systemml
Updated Branches:
refs/heads/master ca04d7cdd -> 9d8fc723c
[SYSTEMML-1679] Add a new threshold utility function
This function accepts a matrix X and a threshold parameter thresh and returns an indicator matrix
of the same shape with values in {0, 1}: an entry is 1 if the corresponding value in X is strictly
greater than thresh, and 0 otherwise. It can be used, for example, to determine the predicted class
in a binary classification problem given the output of a sigmoid layer.
Closes #548.
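The snippet below is a plain-Python sketch (not SystemML code) of the semantics of the new `util::threshold` function, written so it runs without a SystemML installation. The `threshold` and `sigmoid` helpers are illustrative stand-ins for the DML function added in this commit and for a sigmoid layer, and the raw scores in the second example are made up. It reproduces the 3x3 case from the new test in `test.dml` and shows that, because the DML body is `out = X > thresh`, a value exactly equal to the threshold maps to 0.

```python
import math


def sigmoid(z):
    """Logistic sigmoid (illustrative stand-in for a sigmoid layer)."""
    return 1.0 / (1.0 + math.exp(-z))


def threshold(X, thresh):
    """Hypothetical Python mirror of the DML `out = X > thresh`.

    Returns a matrix (list of rows) of the same shape as X, with 1.0 where
    the value is strictly greater than thresh and 0.0 otherwise.
    """
    return [[1.0 if x > thresh else 0.0 for x in row] for row in X]


# Same 3x3 input and threshold as the new test in scripts/nn/test/test.dml.
X = [[0.31, 0.24, 0.87],
     [0.45, 0.66, 0.65],
     [0.24, 0.91, 0.13]]
print(threshold(X, 0.5))

# Binary classification use case: made-up raw scores for three examples
# (one column), passed through a sigmoid and then thresholded at 0.5.
scores = [[-2.0], [0.0], [3.0]]
probs = [[sigmoid(z) for z in row] for row in scores]
# sigmoid(0.0) is exactly 0.5, which is not > 0.5, so the middle
# example is assigned class 0.
print(threshold(probs, 0.5))
```

Whether a tie at the threshold should map to 0 or 1 is a design choice; the strict `>` in the DML implementation makes ties map to 0.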
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9d8fc723
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9d8fc723
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9d8fc723
Branch: refs/heads/master
Commit: 9d8fc723cdaad5d47692ba0b04e566b2a7d9b1bc
Parents: ca04d7c
Author: Fei Hu <hu...@gmail.com>
Authored: Fri Jun 16 16:16:34 2017 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Fri Jun 16 16:16:34 2017 -0700
----------------------------------------------------------------------
scripts/nn/test/README.md | 4 ++--
scripts/nn/test/run_tests.dml | 1 +
scripts/nn/test/test.dml | 22 ++++++++++++++++++++++
scripts/nn/util.dml | 16 ++++++++++++++++
4 files changed, 41 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/README.md
----------------------------------------------------------------------
diff --git a/scripts/nn/test/README.md b/scripts/nn/test/README.md
index b714d50..0143752 100644
--- a/scripts/nn/test/README.md
+++ b/scripts/nn/test/README.md
@@ -26,7 +26,7 @@ limitations under the License.
#### All layers are tested for correct derivatives ("gradient-checking"), and many layers also have correctness tests against simpler reference implementations.
* `grad_check.dml` - Contains gradient-checks for all layers as individual DML functions.
* `test.dml` - Contains correctness tests for several of the more complicated layers by checking against simple reference implementations, such as `conv_simple.dml`. All tests are formulated as individual DML functions.
-* `tests.dml` - A DML script that runs all of the tests in `grad_check.dml` and `test.dml`.
+* `run_tests.dml` - A DML script that runs all of the tests in `grad_check.dml` and `test.dml`.
## Execution
-* `spark-submit SystemML.jar -f nn/test/tests.dml` from the base of the project.
+* `spark-submit SystemML.jar -f nn/test/run_tests.dml` from the base of the project.
http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index b48606c..c9b1b3e 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -92,6 +92,7 @@ test::im2col()
test::max_pool2d()
test::padding()
test::tanh()
+test::threshold()
print("---")
print("Other tests complete -- look for any ERRORs or WARNINGs.")
http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 52fb063..cfb8c79 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -605,3 +605,25 @@ tanh = function() {
}
}
+threshold = function() {
+ /*
+ * Test for threshold function.
+ */
+ print("Testing the threshold function.")
+
+ # Generate data
+ X = matrix("0.31 0.24 0.87
+ 0.45 0.66 0.65
+ 0.24 0.91 0.13", rows=3, cols=3)
+ thresh = 0.5
+ target_matrix = matrix("0.0 0.0 1.0
+ 0.0 1.0 1.0
+ 0.0 1.0 0.0", rows=3, cols=3)
+
+ # Get the indicator matrix
+ indicator_matrix = util::threshold(X, thresh)
+
+ # Equivalency check
+ out = test_util::check_all_equal(indicator_matrix, target_matrix)
+}
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 3a73f08..c4da16a 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -200,3 +200,19 @@ unpad_image = function(matrix[double] img_padded, int Hin, int Win, int padh, in
}
}
+threshold = function(matrix[double] X, double thresh)
+ return (matrix[double] out) {
+ /*
+ * Computes an indicator matrix with values in {0, 1} depending on
+ * whether or not the values in X are above the input threshold
+ *
+ * Inputs:
+ * - X: Inputs, of shape (any, any).
+ * - thresh: Input threshold.
+ *
+ * Outputs:
+ * - out: Outputs, of same shape as X.
+ */
+ out = X > thresh
+}
+