You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/11/29 05:31:11 UTC
[2/2] incubator-singa git commit: Check and fix cudnn engine for concat and slice layer
Check and fix cudnn engine for concat and slice layer
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/84811118
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/84811118
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/84811118
Branch: refs/heads/master
Commit: 848111181f7c3d6844c53461c0f9dfd43db47b13
Parents: d1110c0
Author: RUAN0007 <ru...@gmail.com>
Authored: Tue Nov 29 10:46:39 2016 +0800
Committer: RUAN0007 <ru...@gmail.com>
Committed: Tue Nov 29 10:46:39 2016 +0800
----------------------------------------------------------------------
python/singa/layer.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/84811118/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/python/singa/layer.py b/python/singa/layer.py
index 0244454..95b78c9 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -814,7 +814,10 @@ class Concat(Layer):
self.in_shapes = input_sample_shapes
self.axis = axis
self.conf.concat_conf.axis = axis
- self.layer = _create_layer(engine, 'Concat')
+ if engine == "cudnn":
+ self.layer = _create_layer('singacuda', 'Concat')
+ else:
+ self.layer = _create_layer(engine, 'Concat')
if input_sample_shapes is not None:
self.setup(input_sample_shapes)
@@ -836,7 +839,10 @@ class Slice(Layer):
self.axis = axis
self.conf.slice_conf.axis = axis
self.conf.slice_conf.slice_point.extend(slice_point)
- self.layer = _create_layer(engine, 'Slice')
+ if engine == "cudnn":
+ self.layer = _create_layer('singacuda', 'Slice')
+ else:
+ self.layer = _create_layer(engine, 'Slice')
if input_sample_shape is not None:
self.setup(input_sample_shape)