Check and fix the cuDNN engine for the Concat and Slice layers
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/84811118 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/84811118 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/84811118 Branch: refs/heads/master Commit: 848111181f7c3d6844c53461c0f9dfd43db47b13 Parents: d1110c0 Author: RUAN0007 <[email protected]> Authored: Tue Nov 29 10:46:39 2016 +0800 Committer: RUAN0007 <[email protected]> Committed: Tue Nov 29 10:46:39 2016 +0800 ---------------------------------------------------------------------- python/singa/layer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/84811118/python/singa/layer.py ---------------------------------------------------------------------- diff --git a/python/singa/layer.py b/python/singa/layer.py index 0244454..95b78c9 100644 --- a/python/singa/layer.py +++ b/python/singa/layer.py @@ -814,7 +814,10 @@ class Concat(Layer): self.in_shapes = input_sample_shapes self.axis = axis self.conf.concat_conf.axis = axis - self.layer = _create_layer(engine, 'Concat') + if engine == "cudnn": + self.layer = _create_layer('singacuda', 'Concat') + else: + self.layer = _create_layer(engine, 'Concat') if input_sample_shapes is not None: self.setup(input_sample_shapes) @@ -836,7 +839,10 @@ class Slice(Layer): self.axis = axis self.conf.slice_conf.axis = axis self.conf.slice_conf.slice_point.extend(slice_point) - self.layer = _create_layer(engine, 'Slice') + if engine == "cudnn": + self.layer = _create_layer('singacuda', 'Slice') + else: + self.layer = _create_layer(engine, 'Slice') if input_sample_shape is not None: self.setup(input_sample_shape)
