Repository: systemml Updated Branches: refs/heads/master 92ee2cbf8 -> 8b054804e
[SYSTEMML-445] Bugfix in Caffe2DML/Keras2DML's concat layer for sentence CNN Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8b054804 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8b054804 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8b054804 Branch: refs/heads/master Commit: 8b054804e64ba48ca87016dbe82a4349489b031d Parents: 92ee2cb Author: Niketan Pansare <[email protected]> Authored: Thu Feb 8 15:06:19 2018 -0800 Committer: Niketan Pansare <[email protected]> Committed: Thu Feb 8 15:06:19 2018 -0800 ---------------------------------------------------------------------- src/main/python/systemml/mllearn/keras2caffe.py | 6 ++---- src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala | 10 ++++++---- 2 files changed, 8 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/8b054804/src/main/python/systemml/mllearn/keras2caffe.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mllearn/keras2caffe.py b/src/main/python/systemml/mllearn/keras2caffe.py index 89ec5e4..ac3ba80 100755 --- a/src/main/python/systemml/mllearn/keras2caffe.py +++ b/src/main/python/systemml/mllearn/keras2caffe.py @@ -76,10 +76,8 @@ def _getInboundLayers(layer): for node in inbound_nodes: node_list = node.inbound_layers # get layers pointing to this node in_names = in_names + node_list - if any('flat' in s.name for s in in_names): # For Caffe2DML to reroute any use of Flatten layers - return _getInboundLayers([s for s in in_names if 'flat' in s.name][0]) - return in_names - + # For Caffe2DML to reroute any use of Flatten layers + return list(chain.from_iterable( [ _getInboundLayers(l) if isinstance(l, keras.layers.Flatten) else [ l ] for l in in_names ] )) def _getCompensatedAxis(layer): compensated_axis = layer.axis http://git-wip-us.apache.org/repos/asf/systemml/blob/8b054804/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala ---------------------------------------------------------------------- diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala index 2c81eda..9b4736a 100644 --- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala +++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala @@ -439,6 +439,8 @@ class Concat(val param: LayerParameter, val id: Int, val net: CaffeNetwork) exte // This is useful because we do not support multi-input cbind and rbind in DML. def _getMultiFn(fn: String): String = { if (_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList + if(_childLayers.length < 2) + throw new DMLRuntimeException("Incorrect usage of Concat layer. Expected atleast 2 bottom layers, but found " + _childLayers.length) var tmp = fn + "(" + _childLayers(0).out + ", " + _childLayers(1).out + ")" for (i <- 2 until _childLayers.size) { tmp = fn + "(" + tmp + ", " + _childLayers(i).out + ")" @@ -492,20 +494,20 @@ class Concat(val param: LayerParameter, val id: Int, val net: CaffeNetwork) exte else " = " + dOutVar + "[," + indexString + " ]; " // concat_start_index = concat_end_index + 1 - // concat_end_index = concat_start_index + $$ - 1 + // concat_end_index = concat_start_index + ## - 1 val initializeIndexString = "concat_start_index" + outSuffix + " = concat_end_index" + outSuffix + " + 1; concat_end_index" + outSuffix + - " = concat_start_index" + outSuffix + " + $$ - 1; " + " = concat_start_index" + outSuffix + " + ## - 1; " if (param.getConcatParam.getAxis == 0) { bottomLayers.map(l => { dmlScript - .append(initializeIndexString.replaceAll("$$", nrow(l.out))) + .append(initializeIndexString.replaceAll("##", nrow(l.out))) // X1 = Z[concat_start_index:concat_end_index,] .append(dX(l.id) + outSuffix + doutVarAssignment) }) } else { bottomLayers.map(l => { dmlScript - .append(initializeIndexString.replaceAll("$$", int_mult(l.outputShape._1, l.outputShape._2, l.outputShape._3))) + .append(initializeIndexString.replaceAll("##", int_mult(l.outputShape._1, l.outputShape._2, l.outputShape._3))) // X1 = Z[concat_start_index:concat_end_index,] .append(dX(l.id) + outSuffix + doutVarAssignment) })
