convert param for both v3 and v4
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3f39cfa2 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3f39cfa2 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3f39cfa2 Branch: refs/heads/master Commit: 3f39cfa246b477329ce349b8d2bcf10e5353bf37 Parents: 40124db Author: wangwei <[email protected]> Authored: Fri Jul 7 13:20:03 2017 +0800 Committer: wangwei <[email protected]> Committed: Fri Jul 7 13:20:03 2017 +0800 ---------------------------------------------------------------------- examples/imagenet/inception/inception_v3.py | 456 +++++++++++++---------- examples/imagenet/inception/inception_v4.py | 87 ++--- 2 files changed, 305 insertions(+), 238 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3f39cfa2/examples/imagenet/inception/inception_v3.py ---------------------------------------------------------------------- diff --git a/examples/imagenet/inception/inception_v3.py b/examples/imagenet/inception/inception_v3.py index dc7b24b..421343b 100644 --- a/examples/imagenet/inception/inception_v3.py +++ b/examples/imagenet/inception/inception_v3.py @@ -28,52 +28,63 @@ from __future__ import division from __future__ import print_function from singa.layer import Conv2D, Activation, MaxPooling2D, AvgPooling2D,\ - Split, Concat, Dropout, Flatten, Dense, BatchNormalization + Split, Concat, Dropout, Flatten, BatchNormalization from singa import net as ffnet ffnet.verbose = True + def conv2d(net, name, nb_filter, k, s=1, border_mode='SAME', src=None): if type(k) is list: k = (k[0], k[1]) - net.add(Conv2D(name, nb_filter, k, s, border_mode=border_mode, use_bias=False), src) + net.add(Conv2D(name, nb_filter, k, s, border_mode=border_mode, + use_bias=False), src) net.add(BatchNormalization('%s/BatchNorm' % name)) return net.add(Activation(name+'/relu')) -def inception_v3_base(name, sample_shape, final_endpoint='Mixed_6e', aux_name=None, depth_multiplier=1, min_depth=16): +def inception_v3_base(name, sample_shape, final_endpoint, aux_endpoint, + depth_multiplier=1, min_depth=16): """Creates the Inception V3 network up to the given final endpoint. Args: - inputs: a 4-D tensor of size [batch_size, height, width, 3]. + sample_shape: input image sample shape, 3d tuple final_endpoint: specifies the endpoint to construct the network up to. + aux_endpoint: for aux loss. Returns: logits: the logits outputs of the model. end_points: the set of end_points from the inception model. Raises: - ValueError: if final_endpoint is not set to one of the predefined values, + ValueError: if final_endpoint is not set to one of the predefined values """ - endpoints = {} - def final_aux_check(block_name, net): + V3 = 'InceptionV3' + end_points = {} + net = ffnet.FeedForwardNet() + + def final_aux_check(block_name): if block_name == final_endpoint: - return net, endpoints[block_name], endpoints - if block_name == aux_name: - endpoints[aux_name + '-aux'] = net.add(Split('%s-aux' % aux_name, 2)) + return True + if block_name == aux_endpoint: + aux = aux_endpoint + '-aux' + end_points[aux] = net.add(Split(aux, 2)) + return False - net = ffnet.FeedForwardNet() - depth = lambda d: max(int(d * depth_multiplier), min_depth) - V3 = 'InceptionV3' + def depth(d): + return max(int(d * depth_multiplier), min_depth) - name = V3 + '/Conv2d_1a_3x3' + blk = V3 + '/Conv2d_1a_3x3' # 299 x 299 x 3 - net.add(Conv2D(name, depth(32), 3, 2, border_mode='VALID', use_bias=False, input_sample_shape=sample_shape)) - net.add(BatchNormalization(name + '/BatchNorm')) - net.add(Activation(name + '/relu')) - # 149 x 149 x 32 + net.add(Conv2D(blk, depth(32), 3, 2, border_mode='VALID', use_bias=False, + input_sample_shape=sample_shape)) + net.add(BatchNormalization(blk + '/BatchNorm')) + end_points[blk] = net.add(Activation(blk + '/relu')) + if final_aux_check(blk): + return net, end_points + # 149 x 149 x 32 conv2d(net, '%s/Conv2d_2a_3x3' % V3, depth(32), 3, border_mode='VALID') # 147 x 147 x 32 conv2d(net, '%s/Conv2d_2b_3x3' % V3, depth(64), 3) @@ -85,221 +96,274 @@ def inception_v3_base(name, sample_shape, final_endpoint='Mixed_6e', aux_name=No conv2d(net, '%s/Conv2d_4a_3x3' % V3, depth(192), 3, border_mode='VALID') # 71 x 71 x 192. net.add(MaxPooling2D('%s/MaxPool_5a_3x3' % V3, 3, 2, border_mode='VALID')) - # 35 x 35 x 192. + # 35 x 35 x 192. + blk = V3 + '/Mixed_5b' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(48), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % blk, depth(64), 5) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3) + net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(32), 1) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points - m5b = V3 + '/Mixed_5b' - s = net.add(Split('%s/Split' % m5b, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m5b, depth(64), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m5b, depth(48), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % m5b, depth(64), 5) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m5b, depth(64), 1, src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m5b, depth(96), 3) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % m5b, depth(96), 3) - net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m5b, 3, 1), s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m5b, depth(32), 1) - endpoints[m5b] =net.add(Concat('%s/Concat' % m5b, 1), [br0, br1, br2, br3]) - final_aux_check(m5b, net) # mixed_1: 35 x 35 x 288. - m5c = V3 + '/Mixed_5c' - s = net.add(Split('%s/Split' % m5c, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m5c, depth(64), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Con2d_0b_1x1' % m5c, depth(48), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv_1_0c_5x5' % m5c, depth(64), 5) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m5c, depth(64), 1, src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m5c, depth(96), 3) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % m5c, depth(96), 3) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m5c, 3, 1), src=s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m5c, depth(64), 1) - endpoints[m5c] = net.add(Concat('%s/Concat' % m5c, 1), [br0, br1, br2, br3]) - final_aux_check(m5c, net) + blk = V3 + '/Mixed_5c' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x1' % blk, depth(48), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv_1_0c_5x5' % blk, depth(64), 5) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), src=s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(64), 1) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points # mixed_2: 35 x 35 x 288. - m5d = V3 + '/Mixed_5d' - s = net.add(Split('%s/Split' % m5d, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m5d, depth(64), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m5d, depth(48), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % m5d, depth(64), 5) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m5d, depth(64), 1, src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m5d, depth(96), 3) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % m5d, depth(96), 3) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m5d, 3, 1), s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m5d, depth(64), 1) - endpoints[m5d] =net.add(Concat('%s/Concat' % m5d, 1), [br0, br1, br2, br3]) - final_aux_check(m5d, net) + blk = V3 + '/Mixed_5d' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(48), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % blk, depth(64), 5) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(64), 1) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points # mixed_3: 17 x 17 x 768. - m6a = V3 + '/Mixed_6a' - s = net.add(Split('%s/Split' % m6a, 3)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_1x1' % m6a, depth(384), 3, 2, border_mode='VALID', src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6a, depth(64), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x3' % m6a, depth(96), 3) - br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_1x1' % m6a, depth(96), 3, 2, border_mode='VALID') - br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % m6a, 3, 2, border_mode='VALID'), s) - endpoints[m6a] = net.add(Concat('%s/Concat' % m6a, 1), [br0, br1, br2]) - final_aux_check(m6a, net) + blk = V3 + '/Mixed_6a' + s = net.add(Split('%s/Split' % blk, 3)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_1x1' % blk, depth(384), 3, 2, + border_mode='VALID', src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(64), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x3' % blk, depth(96), 3) + br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_1x1' % blk, depth(96), 3, 2, + border_mode='VALID') + br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2, + border_mode='VALID'), s) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2]) + if final_aux_check(blk): + return net, end_points # mixed4: 17 x 17 x 768. - m6b = V3 + '/Mixed_6b' - s = net.add(Split('%s/Split' % m6b, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6b, depth(192), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6b, depth(128), 1, src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6b, depth(128), [1, 7]) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6b, depth(192), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6b, depth(128), [1, 1], src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6b, depth(128), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6b, depth(128), [1, 7]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6b, depth(128), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6b, depth(192), [1, 7]) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6b, 3, 1), s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6b, depth(192), [1, 1]) - endpoints[m6b] = net.add(Concat('%s/Concat' % m6b, 1), [br0, br1, br2, br3]) - final_aux_check(m6b, net) + blk = V3 + '/Mixed_6b' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(128), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(128), [1, 7]) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(128), [1, 1], + src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(128), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(128), [1, 7]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(128), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7]) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1]) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points # mixed_5: 17 x 17 x 768. - m6c = V3 + '/Mixed_6c' - s = net.add(Split('%s/Split' % m6c, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6c, depth(192), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6c, depth(160), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6c, depth(160), [1, 7]) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6c, depth(192), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6c, depth(160), [1, 1], src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6c, depth(160), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6c, depth(160), [1, 7]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6c, depth(160), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6c, depth(192), [1, 7]) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6c, 3, 1), s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6c, depth(192), [1, 1]) - endpoints[m6c] = net.add(Concat('%s/Concat' % m6c, 1), [br0, br1, br2, br3]) - final_aux_check(m6c, net) + blk = V3 + '/Mixed_6c' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1], + src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(160), [1, 1], + src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(160), [1, 7]) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(160), [1, 1], + src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(160), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(160), [1, 7]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(160), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7]) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1]) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points # mixed_6: 17 x 17 x 768. - m6d = V3 + '/Mixed_6d' - s = net.add(Split('%s/Split' % m6d, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6d, depth(192), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6d, depth(160), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6d, depth(160), [1, 7]) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6d, depth(192), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6d, depth(160), [1, 1], src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6d, depth(160), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6d, depth(160), [1, 7]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6d, depth(160), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6d, depth(192), [1, 7]) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6d, 3, 1), s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6d, depth(192), [1, 1]) - endpoints[m6d] = net.add(Concat('%s/Concat' % m6d, 1), [br0, br1, br2, br3]) - final_aux_check(m6d, net) - - m6e = V3 + '/Mixed_6e' - s = net.add(Split('%s/Split' % m6e, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6e, depth(192), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6e, depth(192), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6e, depth(192), [1, 7]) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6e, depth(192), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6e, depth(192), [1, 1], src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6e, depth(192), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6e, depth(192), [1, 7]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6e, depth(192), [7, 1]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6e, depth(192), [1, 7]) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6d, 3, 1), s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6d, depth(192), [1, 1]) - endpoints[m6e] = net.add(Concat('%s/Concat' % m6d, 1), [br0, br1, br2, br3]) - final_aux_check(m6e, net) + blk = V3 + '/Mixed_6d' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1], + src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(160), [1, 1], + src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(160), [1, 7]) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(160), [1, 1], + src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(160), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(160), [1, 7]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(160), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7]) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1]) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points + + blk = V3 + '/Mixed_6e' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1], + src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(192), [1, 1], + src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(192), [1, 7]) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(192), [1, 1], + src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(192), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(192), [1, 7]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(192), [7, 1]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7]) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1]) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points # mixed_8: 8 x 8 x 1280. - m7a = V3 + '/Mixed_7a' - s = net.add(Split('%s/Split' % m7a, 3)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m7a, depth(192), [1, 1], src=s) - br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % m7a, depth(320), [3, 3], 2, border_mode='VALID') - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m7a, depth(192), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m7a, depth(192), [1, 7]) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m7a, depth(192), [7, 1]) - br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % m7a, depth(192), [3, 3], 2, border_mode='VALID') - br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % m7a, 3, 2, border_mode='VALID'), s) - endpoints[m7a] = net.add(Concat('%s/Concat' % m7a, 1), [br0, br1, br2]) - final_aux_check(m7a, net) + blk = V3 + '/Mixed_7a' + s = net.add(Split('%s/Split' % blk, 3)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1], + src=s) + br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, depth(320), [3, 3], 2, + border_mode='VALID') + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(192), [1, 1], + src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(192), [1, 7]) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1]) + br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, depth(192), [3, 3], 2, + border_mode='VALID') + br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2, + border_mode='VALID'), s) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2]) + if final_aux_check(blk): + return net, end_points # mixed_9: 8 x 8 x 2048. - m7b = V3 + '/Mixed_7b' - s = net.add(Split('%s/Split' % m7b, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m7b, depth(320), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m7b, depth(384), [1, 1], src=s) - s1 = net.add(Split('%s/Branch_1/Split1' % m7b, 2)) - br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % m7b, depth(384), [1, 3], src=s1) - br12 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x1' % m7b, depth(384), [3, 1], src=s1) - br1 = net.add(Concat('%s/Branch_1/Concat1' % m7b, 1), [br11, br12]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m7b, depth(448), [1, 1], src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m7b, depth(384), [3, 3]) - s2 = net.add(Split('%s/Branch_2/Split2' % m7b, 2)) - br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % m7b, depth(384), [1, 3], src=s2) - br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % m7b, depth(384), [3, 1], src=s2) - br2 = net.add(Concat('%s/Branch_2/Concat2' % m7b, 1), [br21, br22]) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m7b, 3, 1), src=s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m7b, depth(192), [1, 1]) - endpoints[m7b] = net.add(Concat('%s/Concat' % m7b, 1), [br0, br1, br2, br3]) - final_aux_check(m7b, net) + blk = V3 + '/Mixed_7b' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(320), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(384), 1, src=s) + s1 = net.add(Split('%s/Branch_1/Split1' % blk, 2)) + br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % blk, depth(384), [1, 3], + src=s1) + br12 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x1' % blk, depth(384), [3, 1], + src=s1) + br1 = net.add(Concat('%s/Branch_1/Concat1' % blk, 1), [br11, br12]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(448), 1, src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(384), 3) + s2 = net.add(Split('%s/Branch_2/Split2' % blk, 2)) + br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, depth(384), [1, 3], + src=s2) + br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % blk, depth(384), [3, 1], + src=s2) + br2 = net.add(Concat('%s/Branch_2/Concat2' % blk, 1), [br21, br22]) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), src=s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1]) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + if final_aux_check(blk): + return net, end_points # mixed_10: 8 x 8 x 2048. - m7c = V3 + '/Mixed_7c' - s = net.add(Split('%s/Split' % m7c, 4)) - br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m7c, depth(320), [1, 1], src=s) - br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m7c, depth(384), [1, 1], src=s) - s1 = net.add(Split('%s/Branch_1/Split1' % m7c, 2)) - br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % m7c, depth(384), [1, 3], src=s1) - br12 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x1' % m7c, depth(384), [3, 1], src=s1) - br1 = net.add(Concat('%s/Branch_1/Concat1' % m7c, 1), [br11, br12]) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m7c, depth(448), [1, 1], src=s) - br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m7c, depth(384), [3, 3]) - s2 = net.add(Split('%s/Branch_2/Split2' % m7c, 2)) - br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % m7c, depth(384), [1, 3], src=s2) - br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % m7c, depth(384), [3, 1], src=s2) - br2 = net.add(Concat('%s/Branch_2/Concat2' % m7c, 1), [br21, br22]) - br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m7c, 3, 1), src=s) - br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m7c, depth(192), [1, 1]) - endpoints[m7c] = net.add(Concat('%s/Concat' % m7c, 1), [br0, br1, br2, br3]) - final_aux_check(m7c, net) - return net, endpoints[m7c], endpoints - - -def create_net(num_classes=1001, sample_shape=(3, 299, 299), is_training=True, dropout_keep_prob=0.8, create_aux_logits=True): + blk = V3 + '/Mixed_7c' + s = net.add(Split('%s/Split' % blk, 4)) + br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(320), 1, src=s) + br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(384), 1, src=s) + s1 = net.add(Split('%s/Branch_1/Split1' % blk, 2)) + br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % blk, depth(384), [1, 3], + src=s1) + br12 = conv2d(net, '%s/Branch_1/Conv2d_0c_3x1' % blk, depth(384), [3, 1], + src=s1) + br1 = net.add(Concat('%s/Branch_1/Concat1' % blk, 1), [br11, br12]) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(448), [1, 1], + src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(384), [3, 3]) + s2 = net.add(Split('%s/Branch_2/Split2' % blk, 2)) + br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, depth(384), [1, 3], + src=s2) + br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % blk, depth(384), [3, 1], + src=s2) + br2 = net.add(Concat('%s/Branch_2/Concat2' % blk, 1), [br21, br22]) + br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), src=s) + br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1]) + end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), + [br0, br1, br2, br3]) + assert final_endpoint == blk, \ + 'final_enpoint = %s is not in the net' % final_endpoint + return net, end_points + + +def create_net(num_classes=1001, sample_shape=(3, 299, 299), + final_endpoint='InceptionV3/Mixed_7c', + aux_endpoint='InceptionV3/Mixed_6e', + dropout_keep_prob=0.8): """Creates the Inception V4 model. Args: num_classes: number of predicted classes. - is_training: whether is training or not. dropout_keep_prob: float, the fraction to keep before final layer. - reuse: whether or not the network and its variables should be reused. To be - able to reuse 'scope' must be given. - create_aux_logits: Whether to include the auxiliary logits. + final_endpoint: 'InceptionV3/Mixed_7d', + aux_endpoint: Returns: logits: the logits outputs of the model. end_points: the set of end_points from the inception model. """ name = 'InceptionV3' - if is_training and create_aux_logits: - aux_name = name + '/Mixed_6e' - else: - aux_name = None - net, last_layer, end_points = inception_v3_base(name, sample_shape, aux_name=aux_name) + net, end_points = inception_v3_base(name, sample_shape, final_endpoint, + aux_endpoint) # Auxiliary Head logits - if aux_name is not None: + if aux_endpoint is not None: # 8 x 8 x 1280 - aux_logits = end_points[aux_name + '-aux'] - net.add(AvgPooling2D('%s/AuxLogits/AvgPool_1a_5x5' % name, 5, stride=3, border_mode='VALID'), aux_logits) - t = conv2d(net, '%s/AuxLogits/Conv2d_1b_1x1' % name, 128, 1) - conv2d(net, '%s/AuxLogits/Conv2d_2a' % name, 768, t.get_output_sample_shape()[1:3], border_mode='VALID') - net.add(Flatten('%s/AuxLogits/flat' % name)) - end_points['AuxLogits'] = net.add(Dense('%s/AuxLogits/Aux_logits' % name, num_classes)) + aux_logits = end_points[aux_endpoint + '-aux'] + blk = name + '/AuxLogits' + net.add(AvgPooling2D('%s/AvgPool_1a_5x5' % blk, 5, stride=3, + border_mode='VALID'), aux_logits) + t = conv2d(net, '%s/Conv2d_1b_1x1' % blk, 128, 1) + s = t.get_output_sample_shape()[1:3] + conv2d(net, '%s/Conv2d_2a_%dx%d' % (blk, s[0], s[1]), 768, s, + border_mode='VALID') + net.add(Conv2D('%s/Conv2d_2b_1x1' % blk, num_classes, 1)) + net.add(Flatten('%s/flat' % blk)) # Final pooling and prediction # 8 x 8 x 2048 - net.add(AvgPooling2D('%s/Logits/AvgPool_1a' % name, last_layer.get_output_sample_shape()[1:3], 1, border_mode='VALID'), last_layer) + blk = name + '/Logits' + last_layer = end_points[final_endpoint] + net.add(AvgPooling2D('%s/AvgPool_1a' % blk, + last_layer.get_output_sample_shape()[1:3], 1, + border_mode='VALID'), last_layer) # 1 x 1 x 2048 - net.add(Dropout('%s/Logits/Dropout_1b' % name, 1 - dropout_keep_prob)) - net.add(Flatten('%s/Logits/PreLogitsFlatten' % name)) + net.add(Dropout('%s/Dropout_1b' % blk, 1 - dropout_keep_prob)) + net.add(Conv2D('%s/Conv2d_1c_1x1' % blk, num_classes, 1)) + end_points[blk] = net.add(Flatten('%s/flat' % blk)) # 2048 - end_points['Logits'] = net.add(Dense('%s/Logits/Logits' % name, num_classes)) return net, end_points http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3f39cfa2/examples/imagenet/inception/inception_v4.py ---------------------------------------------------------------------- diff --git a/examples/imagenet/inception/inception_v4.py b/examples/imagenet/inception/inception_v4.py index b641706..9c5883f 100644 --- a/examples/imagenet/inception/inception_v4.py +++ b/examples/imagenet/inception/inception_v4.py @@ -39,9 +39,10 @@ from singa import net as ffnet ffnet.verbose = True + def conv2d(net, name, nb_filter, k, s=1, border_mode='SAME', src=None): net.add(Conv2D(name, nb_filter, k, s, border_mode=border_mode, - use_bias=False), src) + use_bias=False), src) net.add(BatchNormalization('%s/BatchNorm' % name)) return net.add(Activation(name+'/relu')) @@ -66,13 +67,13 @@ def block_reduction_a(blk, net): # By default use stride=1 and SAME padding s = net.add(Split('%s/Split' % blk, 3)) br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 384, 3, 2, - border_mode='VALID', src=s) + border_mode='VALID', src=s) br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, 192, 1, src=s) br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x3' % blk, 224, 3) br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, 256, 3, 2, - border_mode='VALID') + border_mode='VALID') br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2, - border_mode='VALID'), s) + border_mode='VALID'), s) return net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2]) @@ -97,17 +98,17 @@ def block_inception_b(blk, net): def block_reduction_b(blk, net): """Builds Reduction-B block for Inception v4 network.""" # By default use stride=1 and SAME padding - s = net.add(Split('%s/Split' % blk , 3)) + s = net.add(Split('%s/Split' % blk, 3)) br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, 192, 1, src=s) br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 192, 3, 2, - border_mode='VALID') + border_mode='VALID') br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, 256, 1, src=s) br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, 256, (1, 7)) br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, 320, (7, 1)) br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, 320, 3, 2, - border_mode='VALID') + border_mode='VALID') br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2, - border_mode='VALID'), s) + border_mode='VALID'), s) return net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2]) @@ -123,9 +124,9 @@ def block_inception_c(blk, net): br11 = conv2d(net, '%s/Branch_1/Conv2d_0c_3x1' % blk, 256, (3, 1), src=br1) br1 = net.add(Concat('%s/Branch_1/Concat' % blk, 1), [br10, br11]) - br2 =conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, 384, 1, src=s) - br2 =conv2d(net, '%s/Branch_2/Conv2d_0b_3x1' % blk, 448, (3, 1)) - br2 =conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, 512, (1, 3)) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, 384, 1, src=s) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x1' % blk, 448, (3, 1)) + br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, 512, (1, 3)) br2 = net.add(Split('%s/Branch_2/Split' % blk, 2)) br20 = conv2d(net, '%s/Branch_2/Conv2d_0d_1x3' % blk, 256, (1, 3), src=br2) br21 = conv2d(net, '%s/Branch_2/Conv2d_0e_3x1' % blk, 256, (3, 1), src=br2) @@ -136,8 +137,8 @@ def block_inception_c(blk, net): return net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2, br3]) -def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d', - aux_endpoint='Inception/Mixed_6e'): +def inception_v4_base(sample_shape, final_endpoint='Inception/Mixed_7d', + aux_endpoint='Inception/Mixed_6e'): """Creates the Inception V4 network up to the given final endpoint. Endpoint name list: 'InceptionV4/' + @@ -148,17 +149,18 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d', 'Mixed_7d'] Args: - inputs: a 4-D tensor of size [batch_size, height, width, 3]. + sample_shape: input image sample shape, 3d tuple final_endpoint: specifies the endpoint to construct the network up to. aux_endpoint: for aux loss. Returns: the neural net - the last layer the set of end_points from the inception model. """ + name = 'InceptionV4' end_points = {} net = ffnet.FeedForwardNet() + def final_aux_check(block_name): if block_name == final_endpoint: return True @@ -167,36 +169,34 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d', end_points[aux] = net.add(Split(aux, 2)) return False - # 299 x 299 x 3 - blk = name + 'Conv2d_1a_3x3' + blk = name + '/Conv2d_1a_3x3' net.add(Conv2D(blk, 32, 3, 2, border_mode='VALID', use_bias=False, - input_sample_shape=sample_shape)) + input_sample_shape=sample_shape)) net.add(BatchNormalization('%s/BatchNorm' % blk)) end_points[blk] = net.add(Activation('%s/relu' % blk)) if final_aux_check(blk): return net, end_points # 149 x 149 x 32 - blk = name + 'Conv2d_2a_3x3' - end_points[blk] = conv2d(net, '%s/Conv2d_2a_3x3' % blk, 32, 3, - border_mode='VALID') + blk = name + '/Conv2d_2a_3x3' + end_points[blk] = conv2d(net, blk, 32, 3, border_mode='VALID') if final_aux_check(blk): return net, end_points # 147 x 147 x 32 - blk = name + 'Conv2d_2b_3x3' - end_points[blk] = conv2d(net, '%s/Conv2d_2b_3x3' % blk, 64, 3) + blk = name + '/Conv2d_2b_3x3' + end_points[blk] = conv2d(net, blk, 64, 3) if final_aux_check(blk): return net, end_points # 147 x 147 x 64 blk = name + '/Mixed_3a' s = net.add(Split('%s/Split' % blk, 2)) - br0 = net.add(MaxPooling2D('%s/Branch_0/MaxPool_0a_3x3' % blk, 3, - 2, border_mode='VALID'), s) + br0 = net.add(MaxPooling2D('%s/Branch_0/MaxPool_0a_3x3' % blk, 3, 2, + border_mode='VALID'), s) br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_3x3' % blk, 96, 3, 2, - border_mode='VALID', src=s) + border_mode='VALID', src=s) end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1]) if final_aux_check(blk): return net, end_points @@ -206,12 +206,12 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d', s = net.add(Split('%s/Split' % blk, 2)) br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, 64, 1, src=s) br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 96, 3, - border_mode='VALID') + border_mode='VALID') br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, 64, 1, src=s) br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, 64, (1, 7)) br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, 64, (7, 1)) br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, 96, 3, - border_mode='VALID') + border_mode='VALID') end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1]) if final_aux_check(blk): return net, end_points @@ -220,9 +220,9 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d', blk = name + '/Mixed_5a' s = net.add(Split('%s/Split' % blk, 2)) br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 192, 3, 2, - border_mode='VALID', src=s) - br1 = net.add(MaxPooling2D('%s/Branch_1/MaxPool_1a_3x3' % blk, 3, - 2, border_mode='VALID'), s) + border_mode='VALID', src=s) + br1 = net.add(MaxPooling2D('%s/Branch_1/MaxPool_1a_3x3' % blk, 3, 2, + border_mode='VALID'), s) end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1]) if final_aux_check(blk): return net, end_points @@ -265,12 +265,13 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d', if final_aux_check(blk): return net, end_points - return net, end_points + assert final_endpoint == blk, \ + 'final_enpoint = %s is not in the net' % final_endpoint def create_net(num_classes=1001, sample_shape=(3, 299, 299), is_training=True, - dropout_keep_prob=0.8, final_endpoint='InceptionV4/Mixed_7d', - aux_endpoint='InceptionV4/Mixed_6e'): + dropout_keep_prob=0.8, final_endpoint='InceptionV4/Mixed_7d', + aux_endpoint='InceptionV4/Mixed_6e'): """Creates the Inception V4 model. Args: @@ -285,28 +286,30 @@ def create_net(num_classes=1001, sample_shape=(3, 299, 299), is_training=True, """ end_points = {} name = 'InceptionV4' - net, end_points = inception_v4_base(name, sample_shape, - final_endpoint=final_endpoint, aux_endpoint=aux_endpoint) + net, end_points = inception_v4_base(sample_shape, + final_endpoint=final_endpoint, + aux_endpoint=aux_endpoint) # Auxiliary Head logits if aux_endpoint is not None: # 17 x 17 x 1024 aux_logits = end_points[aux_endpoint + '-aux'] - blk = 'AuxLogits' + blk = name + '/AuxLogits' net.add(AvgPooling2D('%s/AvgPool_1a_5x5' % blk, 5, stride=3, - border_mode='VALID'), aux_logits) + border_mode='VALID'), aux_logits) t = conv2d(net, '%s/Conv2d_1b_1x1' % blk, 128, 1) conv2d(net, '%s/Conv2d_2a' % blk, 768, - t.get_output_sample_shape()[1:3], border_mode='VALID') + t.get_output_sample_shape()[1:3], border_mode='VALID') net.add(Flatten('%s/flat' % blk)) end_points[blk] = net.add(Dense('%s/Aux_logits' % blk, num_classes)) # Final pooling and prediction # 8 x 8 x 1536 - blk = 'Logits' + blk = name + '/Logits' last_layer = end_points[final_endpoint] net.add(AvgPooling2D('%s/AvgPool_1a' % blk, - last_layer.get_output_sample_shape()[1:3], border_mode='VALID'), - last_layer) + last_layer.get_output_sample_shape()[1:3], + border_mode='VALID'), + last_layer) # 1 x 1 x 1536 net.add(Dropout('%s/Dropout_1b' % blk, 1 - dropout_keep_prob)) net.add(Flatten('%s/PreLogitsFlatten' % blk))
