You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by du...@apache.org on 2017/06/28 19:25:21 UTC

[2/2] systemml git commit: [SYSTEMML-1736] Add a new 2D top_k utility function

[SYSTEMML-1736] Add a new 2D top_k utility function

This function computes the top k values (i.e. probabilities) and
associated indices (i.e. classes) along the channel axis of the input
matrix X. A typical use case is one in which X is the output of a 2D
softmax layer, so that the channel values at each spatial position form
a set of normalized class probabilities; values and indices then contain
the top k probabilities and class indices at each position. This
scenario is common in image segmentation problems.

Closes #551.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5e7e5777
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5e7e5777
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5e7e5777

Branch: refs/heads/master
Commit: 5e7e57774b1936736d8551bacf3dffd60bc45071
Parents: 2e78eb9
Author: Fei Hu <hu...@gmail.com>
Authored: Wed Jun 28 12:23:13 2017 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Wed Jun 28 12:23:13 2017 -0700

----------------------------------------------------------------------
 scripts/nn/test/run_tests.dml |  1 +
 scripts/nn/test/test.dml      | 85 ++++++++++++++++++++++++++++++++++----
 scripts/nn/util.dml           | 62 ++++++++++++++++++++-------
 3 files changed, 127 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/5e7e5777/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index 4cc2875..f4c33d8 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -100,6 +100,7 @@ test::threshold()
 test::transpose_NCHW_to_CNHW()
 test::top_k_row()
 test::top_k()
+test::top_k2d()
 
 print("---")
 print("Other tests complete -- look for any ERRORs or WARNINGs.")

http://git-wip-us.apache.org/repos/asf/systemml/blob/5e7e5777/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index b1190fc..a1bb6cd 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -493,17 +493,17 @@ transpose_NCHW_to_CNHW = function() {
    * Test for `transpose_NCHW_to_CNHW` function.
    */
   print("Testing transpose_NCHW_to_CNHW function.")
-  
+
   # Generate data
   N = 2
   C = 3
   H = 4
   W = 5
   X = matrix(seq(1, N*C*H*W), rows=N, cols=C*H*W)
-  
+
   out = util::transpose_NCHW_to_CNHW(X, C)
-  
-  target = 
+
+  target =
     matrix("1   2   3   4   5   6   7   8   9   10  11  12  13  14  15  16  17  18  19  20
             61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80
             21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40
@@ -511,7 +511,7 @@ transpose_NCHW_to_CNHW = function() {
             41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60
             101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120",
            rows=C, cols=N*H*W)
-  
+
   # Equivalency check
   for (i in 1:nrow(out)) {
     for(j in 1:ncol(out)) {
@@ -790,7 +790,7 @@ threshold = function() {
 
 top_k_row = function() {
   /*
-   * Test for the top_k function.
+   * Test for the top_k_row function.
    */
   print("Testing the top_k_row function.")
 
@@ -806,7 +806,7 @@ top_k_row = function() {
                              4
                              8", rows=1, cols=3)
 
-  # Test the top 3 for the second row
+  # Test the top 3 for the second row.
   [values, indices] = util::top_k_row(X, 2, 3)
   check_values = test_util::check_all_equal(values, expected_values)
   check_indices = test_util::check_all_equal(indices, expected_indices)
@@ -852,3 +852,74 @@ top_k = function() {
    check_values_topAll = test_util::check_all_equal(values_topAll, expected_values_topAll)
    check_indices_topAll = test_util::check_all_equal(indices_topAll, expected_indices_topAll)
 }
+
+top_k2d = function() {
+  /*
+   * Test for the top_k2d function.
+   */
+  print("Testing the top_k2d function.")
+  # Generate data, of shape (N, C, Hin, Win) = (2, 3, 3, 4).
+  k = 2
+  X = matrix("0.1 0.4 0.4 0.5
+              0.4 0.1 0.6 0.1
+              0.7 0.7 0.3 0.2
+
+              0.2 0.5 0.4 0.5
+              0.4 0.1 0.6 0.1
+              0.7 0.8 0.3 0.2
+
+              0.3 0.4 0.4 0.5
+              0.4 0.1 0.6 0.1
+              0.7 0.2 0.3 0.2
+
+              0.1 0.4 0.4 0.5
+              0.4 0.1 0.6 0.1
+              0.7 0.7 0.3 0.2
+
+              0.2 0.5 0.4 0.5
+              0.4 0.1 0.6 0.1
+              0.7 0.8 0.3 0.2
+
+              0.3 0.4 0.4 0.5
+              0.4 0.1 0.6 0.1
+              0.7 0.2 0.3 0.2", rows=2, cols=3*3*4)
+  # Expected top-k values at each position, of shape (N, k, Hin, Win) = (2, 2, 3, 4).
+  expected_values = matrix("0.3  0.5  0.4  0.5
+                            0.4  0.1  0.6  0.1
+                            0.7  0.8  0.3  0.2
+
+                            0.2  0.4  0.4  0.5
+                            0.4  0.1  0.6  0.1
+                            0.7  0.7  0.3  0.2
+
+                            0.3  0.5  0.4  0.5
+                            0.4  0.1  0.6  0.1
+                            0.7  0.8  0.3  0.2
+
+                            0.2  0.4  0.4  0.5
+                            0.4  0.1  0.6  0.1
+                            0.7  0.7  0.3  0.2", rows=2, cols=k*3*4)
+  # Expected channel indices (1-based); ties resolve to the lower channel index.
+  expected_indices = matrix("3  2  1  1
+                             1  1  1  1
+                             1  2  1  1
+
+                             2  1  2  2
+                             2  2  2  2
+                             2  1  2  2
+
+                             3  2  1  1
+                             1  1  1  1
+                             1  2  1  1
+
+                             2  1  2  2
+                             2  2  2  2
+                             2  1  2  2", rows=2, cols=k*3*4)
+
+  [values, indices] = util::top_k2d(X, k, 3, 3, 4)  # C=3, Hin=3, Win=4
+
+  # Equivalency check
+  check_values = test_util::check_all_equal(values, expected_values)
+  check_indices = test_util::check_all_equal(indices, expected_indices)
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/5e7e5777/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index fb54c43..102a507 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -204,7 +204,7 @@ threshold = function(matrix[double] X, double thresh)
     return (matrix[double] out) {
   /*
    * Computes an indicator matrix with values in {0, 1} depending on
-   * whether or not the values in X are above the input threshold
+   * whether or not the values in X are above the input threshold.
    *
    * Inputs:
    *  - X: Inputs, of shape (any, any).
@@ -231,7 +231,6 @@ transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double
    *  - out: Transposed output with C rows.
    */
 
-  if(1==1){}
   N = nrow(X)
   D = ncol(X) / C
 
@@ -275,17 +274,17 @@ top_k_row = function(matrix[double] X, integer r, integer k)
     return (matrix[double] values, matrix[double] indices) {
   /*
    * Computes the top k values (i.e. probabilities) and associated
-   * indices (i.e. classes) in the rth row of the input matrix X
+   * indices (i.e. classes) in the rth row of the input matrix X.
    *
    * Inputs:
-   * - X: Inputs, of shape (N D).
-   * - r: Input row number of X to look for
-   * - k: Input number of top elements to look for
+   *  - X: Inputs, of shape (N, D).
+   *  - r: Input row number of X to look for.
+   *  - k: Input number of top elements to look for.
    *
    * Outputs:
-   * - values: The top k values at the rth row, of shape
-   *    (1, k)
-   * - indices: The class indices, of shape (1, k)
+   *  - values: The top k values at the rth row, of shape
+   *    (1, k).
+   *  - indices: The class indices, of shape (1, k).
    */
 
   #TODO: do r & k need to be checked in the valid range
@@ -308,13 +307,13 @@ top_k = function(matrix[double] X, integer k)
    * indices (i.e. classes) for the input matrix X.
    *
    * Inputs:
-   * - X: Inputs, of shape (N D).
-   * - k: Input number of top elements to look for
+   *  - X: Inputs, of shape (N, D).
+   *  - k: Input number of top elements to look for.
    *
    * Outputs:
-   * - values: The top k values along a certain dimension, of shape
-   *    (N, k)
-   * - indices: The indices of classes, of shape (N, K)
+   *  - values: The top k values along a certain dimension, of shape
+   *    (N, k).
+   *  - indices: The indices of classes, of shape (N, K).
    */
   N = nrow(X)
   D = ncol(X)
@@ -327,3 +326,38 @@ top_k = function(matrix[double] X, integer k)
     indices[r, ] = index
   }
 }
+
+top_k2d = function(matrix[double] X, int k, int C, int Hin, int Win)
+     return (matrix[double] values, matrix[double] indices) {
+  /*
+   * Computes the top k values (i.e. probabilities) and associated
+   * indices (i.e. classes) along the channel axis of the input matrix X.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - k: Input number of top elements to look for.
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *
+   * Outputs:
+   *  - values: The top k values along the channel axis, of shape
+   *    (N, k*Hin*Win).
+   *  - indices: The channel indices (i.e. classes), of shape (N, k*Hin*Win).
+   */
+  N = nrow(X)
+  # NOTE(review): k is not range-checked here; presumably 1 <= k <= C is required.
+  # Reshape X from (N, C*Hin*Win) to (N*Hin*Win, C): one row per (n, h, w) position.
+  X_C_NHW = transpose_NCHW_to_CNHW(X, C)
+  X_NHW_C = t(X_C_NHW)
+
+  # Compute the top k of each row of the reshaped matrix, i.e. across the C channels.
+  [values_NHW_K, indices_NHW_K] = top_k(X_NHW_C, k)  # shape: (N*Hin*Win, k)
+
+  values_K_NHW = t(values_NHW_K)  # shape: (k, N*Hin*Win)
+  indices_K_NHW = t(indices_NHW_K)  # shape: (k, N*Hin*Win)
+  # Reshape from (k, N*Hin*Win) to (N, k*Hin*Win), treating N as the "channel" count.
+  values = transpose_NCHW_to_CNHW(values_K_NHW, N)
+  indices = transpose_NCHW_to_CNHW(indices_K_NHW, N)
+}
+