You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by gw...@apache.org on 2017/06/22 18:09:25 UTC
systemml git commit: [SYSTEMML-1728] Reshape util to convert tensors
in NCHW to CNHW format
Repository: systemml
Updated Branches:
refs/heads/master f516e4bdc -> 345682404
[SYSTEMML-1728] Reshape util to convert tensors in NCHW to CNHW format
Closes #552.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/34568240
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/34568240
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/34568240
Branch: refs/heads/master
Commit: 345682404c3fb1348484c375e811ee3f5805a691
Parents: f516e4b
Author: prithvirajsen <se...@us.ibm.com>
Authored: Thu Jun 22 11:05:24 2017 -0700
Committer: Glenn Weidner <gw...@us.ibm.com>
Committed: Thu Jun 22 11:05:25 2017 -0700
----------------------------------------------------------------------
scripts/nn/test/run_tests.dml | 1 +
scripts/nn/test/test.dml | 33 ++++++++++++++++++++++++
scripts/nn/util.dml | 52 ++++++++++++++++++++++++++++++++++++++
3 files changed, 86 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index 5f3ca6e..cca0d0d 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -97,6 +97,7 @@ test::max_pool2d()
test::padding()
test::tanh()
test::threshold()
+test::transpose_NCHW_to_CNHW()
print("---")
print("Other tests complete -- look for any ERRORs or WARNINGs.")
http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 37f9f73..b0899e3 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -488,6 +488,39 @@ padding = function() {
}
}
+transpose_NCHW_to_CNHW = function() {
+ /*
+ * Test for `util::transpose_NCHW_to_CNHW`: a 2 x 60 input (N=2, C=3, H=4, W=5) is compared against a hand-written 3 x 40 `target`.
+ */
+ print("Testing transpose_NCHW_to_CNHW function.")
+
+ # Generate data: X is N x (C*H*W) and holds 1..N*C*H*W; the target below implies row-wise filling (row 1 = 1..60, row 2 = 61..120)
+ N = 2
+ C = 3
+ H = 4
+ W = 5
+ X = matrix(seq(1, N*C*H*W), rows=N, cols=C*H*W)
+ # Run the function under test
+ out = util::transpose_NCHW_to_CNHW(X, C)
+ # Expected result: row c holds the 20 values of channel c from input row 1, followed by those from input row 2
+ target =
+ matrix("1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
+ 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
+ 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
+ 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
+ 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
+ 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120",
+ rows=C, cols=N*H*W)
+
+ # Equivalency check: compare out with target cell by cell. NOTE(review): the loops cover only out's shape (dimensions are not compared) and rel_error is not used afterwards.
+ for (i in 1:nrow(out)) {
+ for(j in 1:ncol(out)) {
+ rel_error = test_util::check_rel_error(as.scalar(out[i,j]),
+ as.scalar(target[i,j]), 1e-10, 1e-12)
+ }
+ }
+}
+
max_pool2d = function() {
/*
* Test for the 2D max pooling functions.
http://git-wip-us.apache.org/repos/asf/systemml/blob/34568240/scripts/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index c4da16a..329f22f 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -216,3 +216,55 @@ threshold = function(matrix[double] X, double thresh)
out = X > thresh
}
+/*
+ * Reshape util: converts a tensor from NCHW layout to CNHW layout,
+ * i.e. swaps the 1st (N) and 2nd (C) dimensions.
+ */
+transpose_NCHW_to_CNHW = function(matrix[double] X, int C) return (matrix[double] out){
+ /*
+ * Inputs:
+ * - X: Input with N rows; each row holds its C channels one after the
+ * other (NCHW), so ncol(X) is C*D with D values per channel.
+ * - C: Number of channels; assumed to divide ncol(X) (not checked here).
+ *
+ * Outputs:
+ * - out: Matrix with C rows and N*D columns; row c holds channel c of every input row, in input-row order.
+ */
+ N = nrow(X)
+ D = ncol(X) / C
+
+ /*
+ * The values of each channel stay contiguous, so this is an easy
+ * reshape. Viewing X as a matrix with N*C rows and D columns puts
+ * channel c of input row n into row (n-1)*C + c, and the task
+ * reduces to re-ordering those rows (followed by the reshape to
+ * the required output shape with C rows).
+ *
+ * The re-ordering needs the source row of every output row. To
+ * bring the c-th channels of all input rows together, the source
+ * rows must be listed in the following order:
+ * [1, 1+C, 1+2C, ..., 1+(N-1)C,
+ * 2, 2+C, ..., 2+(N-1)C,
+ * 3, 3+C, ..., 3+(N-1)C,
+ * .
+ * .
+ * C, 2C, ..., NC]'
+ * outer() below builds these indices as a C x N matrix whose entry
+ * (c, n) is c + (n-1)*C.
+ */
+ col_idx = outer(seq(1,C), C*t(seq(0,N-1)), "+")
+
+ /*
+ * Generate the permutation matrix by:
+ * - reshaping the result of outer into a column of N*C row indices
+ * - invoking table, which sets permut[i, idx[i]] = 1 (idx being that column)
+ */
+ permut = table(seq(1, N*C), matrix(col_idx, rows=N*C, cols=1), N*C, N*C)
+
+ /*
+ * Generate the output by:
+ * - pre-multiplying the (reshaped) X with the permutation matrix, so that row i of the product is row idx[i] of the reshaped X
+ * - reshaping the product to the output shape with C rows and N*D columns
+ */
+ out = matrix(permut %*% matrix(X, rows=N*C, cols=D), rows=C, cols=N*D)
+}