Posted to commits@systemml.apache.org by du...@apache.org on 2017/06/15 22:22:41 UTC

[3/3] systemml git commit: [MINOR] Make use of util::channel_sums function in conv2d_builtin

[MINOR] Make use of util::channel_sums function in conv2d_builtin


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/17838a3d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/17838a3d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/17838a3d

Branch: refs/heads/master
Commit: 17838a3d34d6e3fa860222ff9af2e5c17f6c3a77
Parents: d49ab98
Author: Mike Dusenberry <mw...@us.ibm.com>
Authored: Thu Jun 15 15:20:49 2017 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Thu Jun 15 15:20:49 2017 -0700

----------------------------------------------------------------------
 scripts/nn/layers/conv2d_builtin.dml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/17838a3d/scripts/nn/layers/conv2d_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/conv2d_builtin.dml b/scripts/nn/layers/conv2d_builtin.dml
index 6b066eb..141890e 100644
--- a/scripts/nn/layers/conv2d_builtin.dml
+++ b/scripts/nn/layers/conv2d_builtin.dml
@@ -24,6 +24,7 @@
  *
  * This implementation uses a built-in operator for higher performance.
  */
+source("nn/util.dml") as util
 
 forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
                    int C, int Hin, int Win, int Hf, int Wf,
@@ -127,7 +128,7 @@ backward = function(matrix[double] dout, int Hout, int Wout,
                             input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
 
   # Partial derivatives for bias vector
-  db = rowSums(matrix(colSums(dout), rows=F, cols=Hout*Wout))
+  db = util::channel_sums(dout, F, Hout, Wout)
 }
 
 init = function(int F, int C, int Hf, int Wf)
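The bias-gradient line in backward() changes form but should not change value. The new util::channel_sums(dout, F, Hout, Wout) call stands in for the inline expression rowSums(matrix(colSums(dout), rows=F, cols=Hout*Wout)). Below is a minimal sketch in plain Python that checks this reading. It is not DML and not the library's code.

The sketch assumes that each row of dout holds one example's F*Hout*Wout gradients laid out channel by channel, and that DML's matrix() reshape fills row by row. The helpers col_sums, reshape_rows, bias_grad_old and channel_sums are hypothetical illustrations. Unlike the DML code, they return a plain list rather than an (F, 1) matrix.

```python
def col_sums(X):
    # Sum each column of an N x (F*Hout*Wout) matrix (list of rows).
    return [sum(row[j] for row in X) for j in range(len(X[0]))]


def reshape_rows(v, rows, cols):
    # Row-major reshape of a flat list into a rows x cols list of rows
    # (assumed to match DML's matrix(v, rows=..., cols=...) default).
    assert len(v) == rows * cols
    return [v[i * cols:(i + 1) * cols] for i in range(rows)]


def bias_grad_old(dout, F, Hout, Wout):
    # Mirrors the removed line:
    #   db = rowSums(matrix(colSums(dout), rows=F, cols=Hout*Wout))
    return [sum(r) for r in reshape_rows(col_sums(dout), F, Hout * Wout)]


def channel_sums(X, C, Hin, Win):
    # Hypothetical per-channel sum: for channel c, add every entry of X
    # whose column lies in c's contiguous block of Hin*Win columns.
    block = Hin * Win
    return [sum(row[c * block + k] for row in X for k in range(block))
            for c in range(C)]


if __name__ == "__main__":
    # N=2 examples, F=2 channels, Hout=1, Wout=2 -> 4 columns per row.
    dout = [[1, 2, 3, 4],
            [5, 6, 7, 8]]
    print(bias_grad_old(dout, 2, 1, 2))  # [14, 22]
    print(channel_sums(dout, 2, 1, 2))   # [14, 22]
```

Both routes sum each channel's Hout*Wout block over all N examples, so they agree. Summing columns first, as the original DML does, shrinks the data to a single row before the reshape, which keeps the reshape cheap.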