You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by gw...@apache.org on 2017/06/22 18:09:25 UTC

systemml git commit: [SYSTEMML-1728] Reshape util to convert tensors in NCHW to CNHW format

Repository: systemml
Updated Branches:
  refs/heads/master f516e4bdc -> 345682404


[SYSTEMML-1728] Reshape util to convert tensors in NCHW to CNHW format

Closes #552.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/34568240
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/34568240
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/34568240

Branch: refs/heads/master
Commit: 345682404c3fb1348484c375e811ee3f5805a691
Parents: f516e4b
Author: prithvirajsen <se...@us.ibm.com>
Authored: Thu Jun 22 11:05:24 2017 -0700
Committer: Glenn Weidner <gw...@us.ibm.com>
Committed: Thu Jun 22 11:05:25 2017 -0700

----------------------------------------------------------------------
 scripts/nn/test/run_tests.dml |  1 +
 scripts/nn/test/test.dml      | 33 ++++++++++++++++++++++++
 scripts/nn/util.dml           | 52 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 86 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index 5f3ca6e..cca0d0d 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -97,6 +97,7 @@ test::max_pool2d()
 test::padding()
 test::tanh()
 test::threshold()
+test::transpose_NCHW_to_CNHW()
 
 print("---")
 print("Other tests complete -- look for any ERRORs or WARNINGs.")

http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 37f9f73..b0899e3 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -488,6 +488,39 @@ padding = function() {
   }
 }
 
+transpose_NCHW_to_CNHW = function() {
+  /*
+   * Test for `transpose_NCHW_to_CNHW` function.
+   */
+  print("Testing transpose_NCHW_to_CNHW function.")
+
+  # Generate data (values 1..120 in NCHW order) and the expected CNHW output.
+  N = 2
+  C = 3
+  H = 4
+  W = 5
+  X = matrix(seq(1, N*C*H*W), rows=N, cols=C*H*W)
+  target =
+    matrix("1   2   3   4   5   6   7   8   9   10  11  12  13  14  15  16  17  18  19  20
+            61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80
+            21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40
+            81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100
+            41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60
+            101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120",
+           rows=C, cols=N*H*W)
+  out = util::transpose_NCHW_to_CNHW(X, C)
+  # Equivalency check, shape first (the loop below only visits out's extent).
+  if (nrow(out) != C | ncol(out) != N*H*W) {
+    print("ERROR: transpose_NCHW_to_CNHW output has the wrong shape.")
+  }
+  for (i in 1:nrow(out)) {
+    for(j in 1:ncol(out)) {
+      rel_error = test_util::check_rel_error(as.scalar(out[i,j]),
+                                             as.scalar(target[i,j]), 1e-10, 1e-12)
+    }
+  }
+}
+
 max_pool2d = function() {
   /*
    * Test for the 2D max pooling functions.

http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index c4da16a..329f22f 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -216,3 +216,55 @@ threshold = function(matrix[double] X, double thresh)
   out = X > thresh
 }
 
+transpose_NCHW_to_CNHW = function(matrix[double] X, int C)
+    return (matrix[double] out) {
+  /*
+   * Reshape util for tensors in NCHW format: transposes the 1st and
+   * 2nd dimensions (N and C), yielding the same data in CNHW format.
+   *
+   * Inputs:
+   *  - X: Input with N rows and channels flattened within each row in
+   *      channel-major format (NCHW), i.e. of shape (N, C*H*W).
+   *  - C: Number of channels (dimensionality of depth).  Must be
+   *      positive and evenly divide ncol(X).
+   *
+   * Outputs:
+   *  - out: Transposed output with C rows, of shape (C, N*H*W).  Row c
+   *      holds channel c of example 1, then of example 2, and so on.
+   */
+  if (C <= 0 | ncol(X) %% C != 0) {
+    stop("transpose_NCHW_to_CNHW: ncol(X) must be a positive multiple of C.")
+  }
+  N = nrow(X)
+  D = ncol(X) / C  # values per channel (H*W)
+
+  /*
+   * The channels stay intact, so after reshaping X to N*C rows (one
+   * row per (example, channel) pair, example-major) the task reduces
+   * to re-ordering rows, followed by a reshape to C rows.
+   *
+   * The re-ordering brings the cth channel of every example together.
+   * Entry i of the following column vector is the row of the reshaped
+   * X that must become row i of the re-ordered matrix:
+   * [1, 1+C, 1+2C, ..., 1+(N-1)C,
+   *  2, 2+C, ..., 2+(N-1)C,
+   *  ...,
+   *  C, 2C, ..., NC]'
+   * outer() builds it as a C x N matrix whose row-major reading (which
+   * is how matrix() reshapes) is exactly this vector.
+   */
+  col_idx = outer(seq(1,C), C*t(seq(0,N-1)), "+")
+
+  /*
+   * Build the N*C x N*C permutation matrix with table(): row i has a
+   * single 1, in column col_idx[i] (col_idx flattened row-major).
+   */
+  permut = table(seq(1, N*C), matrix(col_idx, rows=N*C, cols=1), N*C, N*C)
+
+  /*
+   * Generate the output by:
+   * - pre-multiplying the (reshaped) X with the permutation matrix
+   * - reshaping the result to the output shape with C rows
+   */
+  out = matrix(permut %*% matrix(X, rows=N*C, cols=D), rows=C, cols=N*D)
+}