Posted to commits@systemml.apache.org by du...@apache.org on 2017/06/16 23:17:56 UTC

systemml git commit: [SYSTEMML-1679] Add a new threshold utility function

Repository: systemml
Updated Branches:
  refs/heads/master ca04d7cdd -> 9d8fc723c


[SYSTEMML-1679] Add a new threshold utility function

This function accepts a matrix X and a threshold parameter thresh and returns an indicator matrix of
the same shape as X, with values in {0, 1}: 1 where the corresponding value in X is strictly greater
than thresh, and 0 otherwise. It can be used, for example, to determine the predicted class in a
binary classification problem given the output of a sigmoid layer.
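To make the semantics concrete, here is a minimal plain-Python sketch of the elementwise comparison
`X > thresh` that the new DML function computes. It is not part of SystemML: the Python function
below only mirrors the name of `util::threshold`, matrices are modeled as lists of rows, and the
sigmoid outputs are hypothetical. Because the comparison is strict, an entry exactly equal to
thresh maps to 0.

```python
from typing import List

# A matrix modeled as a list of rows (assumption for this sketch only).
Matrix = List[List[float]]


def threshold(X: Matrix, thresh: float) -> Matrix:
    """Reference model of the DML expression `X > thresh`.

    Returns a matrix with the same shape as X holding 1.0 where an entry is
    strictly greater than thresh and 0.0 otherwise, so entries equal to
    thresh become 0.0.
    """
    return [[1.0 if x > thresh else 0.0 for x in row] for row in X]


# Hypothetical sigmoid outputs for three examples in a binary classifier.
probs = [[0.31], [0.50], [0.87]]
preds = threshold(probs, 0.5)
print(preds)  # [[0.0], [0.0], [1.0]]
```

Applied to the 3x3 matrix X and thresh = 0.5 used in the new `test::threshold()` case, this model
yields the same `target_matrix`.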

Closes #548.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9d8fc723
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9d8fc723
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9d8fc723

Branch: refs/heads/master
Commit: 9d8fc723cdaad5d47692ba0b04e566b2a7d9b1bc
Parents: ca04d7c
Author: Fei Hu <hu...@gmail.com>
Authored: Fri Jun 16 16:16:34 2017 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Fri Jun 16 16:16:34 2017 -0700

----------------------------------------------------------------------
 scripts/nn/test/README.md     |  4 ++--
 scripts/nn/test/run_tests.dml |  1 +
 scripts/nn/test/test.dml      | 22 ++++++++++++++++++++++
 scripts/nn/util.dml           | 16 ++++++++++++++++
 4 files changed, 41 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/README.md
----------------------------------------------------------------------
diff --git a/scripts/nn/test/README.md b/scripts/nn/test/README.md
index b714d50..0143752 100644
--- a/scripts/nn/test/README.md
+++ b/scripts/nn/test/README.md
@@ -26,7 +26,7 @@ limitations under the License.
 #### All layers are tested for correct derivatives ("gradient-checking"), and many layers also have correctness tests against simpler reference implementations.
 * `grad_check.dml` - Contains gradient-checks for all layers as individual DML functions.
 * `test.dml` - Contains correctness tests for several of the more complicated layers by checking against simple reference implementations, such as `conv_simple.dml`.  All tests are formulated as individual DML functions.
-* `tests.dml` - A DML script that runs all of the tests in `grad_check.dml` and `test.dml`.
+* `run_tests.dml` - A DML script that runs all of the tests in `grad_check.dml` and `test.dml`.
 
 ## Execution
-* `spark-submit SystemML.jar -f nn/test/tests.dml` from the base of the project.
+* `spark-submit SystemML.jar -f nn/test/run_tests.dml` from the base of the project.

http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index b48606c..c9b1b3e 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -92,6 +92,7 @@ test::im2col()
 test::max_pool2d()
 test::padding()
 test::tanh()
+test::threshold()
 
 print("---")
 print("Other tests complete -- look for any ERRORs or WARNINGs.")

http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 52fb063..cfb8c79 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -605,3 +605,25 @@ tanh = function() {
   }
 }
 
+threshold = function() {
+  /*
+   * Test for threshold function.
+   */
+  print("Testing the threshold function.")
+
+  # Generate data
+  X = matrix("0.31 0.24 0.87
+              0.45 0.66 0.65
+              0.24 0.91 0.13", rows=3, cols=3)
+  thresh = 0.5
+  target_matrix = matrix("0.0 0.0 1.0
+                          0.0 1.0 1.0
+                          0.0 1.0 0.0", rows=3, cols=3)
+
+  # Get the indicator matrix
+  indicator_matrix = util::threshold(X, thresh)
+
+  # Equivalency check
+  out = test_util::check_all_equal(indicator_matrix, target_matrix)
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/9d8fc723/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 3a73f08..c4da16a 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -200,3 +200,19 @@ unpad_image = function(matrix[double] img_padded, int Hin, int Win, int padh, in
   }
 }
 
+threshold = function(matrix[double] X, double thresh)
+    return (matrix[double] out) {
+  /*
+   * Computes an indicator matrix with values in {0, 1} depending on
+   * whether or not the values in X are above the input threshold
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (any, any).
+   *  - thresh: Input threshold.
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as X.
+   */
+  out = X > thresh
+}
+