szha commented on a change in pull request #9614: MobileNetV2
URL: https://github.com/apache/incubator-mxnet/pull/9614#discussion_r164858934
##########
File path: python/mxnet/gluon/model_zoo/vision/mobilenet.py
##########
@@ -158,3 +166,158 @@ def mobilenet0_25(**kwargs):
The context in which to load the pretrained weights.
"""
return get_mobilenet(0.25, **kwargs)
+
+
+# Block helper
class BottleNeck(nn.HybridBlock):
    r"""Inverted-residual bottleneck block used in the MobileNetV2 model from
    `"Inverted Residuals and Linear Bottlenecks: Mobile Networks for
    Classification, Detection and Segmentation"
    <https://arxiv.org/abs/1801.04381>`_ paper.

    The block expands the input by factor ``t`` with a 1x1 convolution,
    applies a 3x3 depthwise convolution (which may downsample), and projects
    back to ``c_out`` channels with a linear 1x1 convolution. An identity
    shortcut is added when the input and output shapes match.

    Parameters
    ----------
    c_in : int
        Number of input channels.
    c_out : int
        Number of output channels.
    t : int
        Layer expansion ratio.
    s : int
        strides
    """

    def __init__(self, c_in, c_out, t, s, **kwargs):
        super(BottleNeck, self).__init__(**kwargs)
        # The residual shortcut is only valid when nothing changes shape:
        # stride 1 and identical channel counts.
        self.use_shortcut = s == 1 and c_in == c_out
        expanded = c_in * t
        with self.name_scope():
            self.out = nn.HybridSequential()
            # 1x1 pointwise expansion.
            self.out.add(nn.Conv2D(expanded, 1, padding=0, use_bias=False))
            self.out.add(nn.BatchNorm(scale=True))
            self.out.add(nn.Activation('relu'))
            # 3x3 depthwise convolution (groups == channels); carries the
            # stride, so this is where any downsampling happens.
            self.out.add(nn.Conv2D(expanded, 3, strides=s, padding=1,
                                   groups=expanded, use_bias=False))
            self.out.add(nn.BatchNorm(scale=True))
            self.out.add(nn.Activation('relu'))
            # 1x1 linear projection — deliberately no activation after the
            # final BatchNorm ("linear bottleneck").
            self.out.add(nn.Conv2D(c_out, 1, padding=0, use_bias=False))
            self.out.add(nn.BatchNorm(scale=True))

    def hybrid_forward(self, F, x):
        y = self.out(x)
        return F.elemwise_add(y, x) if self.use_shortcut else y
+
+
+# Net
class MobileNetV2(nn.HybridBlock):
    r"""MobileNetV2 model from the
    `"Inverted Residuals and Linear Bottlenecks: Mobile Networks for
    Classification, Detection and Segmentation"
    <https://arxiv.org/abs/1801.04381>`_ paper.

    Parameters
    ----------
    classes : int, default 1000
        Number of classes for the output layer.
    w : float, default 1.0
        The width multiplier for controlling the model size. The actual
        number of channels is equal to the original channel size multiplied
        by this multiplier.
    """

    def __init__(self, classes=1000, w=1.0, **kwargs):
        super(MobileNetV2, self).__init__(**kwargs)
        # Bottleneck stage specification from the paper:
        # (expansion t, output channels c, repeat count n, first stride s).
        stages = [
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        ]
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='features_')
            with self.features.name_scope():
                # Stem: full 3x3 convolution, downsampling by 2.
                self.features.add(
                    nn.Conv2D(int(32 * w), 3, strides=2, padding=1,
                              use_bias=False),
                    nn.BatchNorm(scale=True),
                    nn.Activation('relu')
                )

                channels = 32
                for t, c, n, s in stages:
                    for i in range(n):
                        # Only the first block of a stage downsamples; the
                        # rest keep stride 1 so the shortcut can be used.
                        self.features.add(
                            BottleNeck(c_in=int(channels * w),
                                       c_out=int(c * w),
                                       t=t,
                                       s=s if i == 0 else 1))
                        channels = c

                # The final conv width is never narrowed below 1280, even
                # for multipliers smaller than 1.
                last_channels = int(1280 * w) if w > 1.0 else 1280

                self.features.add(
                    nn.Conv2D(last_channels, 1, strides=1, padding=0,
                              use_bias=False),
                    nn.BatchNorm(scale=True),
                    nn.Activation('relu'),
                )
                self.features.add(nn.GlobalAvgPool2D())

            # A 1x1 convolution plays the role of the fully-connected
            # classifier; flatten turns (N, classes, 1, 1) into (N, classes).
            self.output = nn.Conv2D(channels=classes, kernel_size=1,
                                    strides=1, padding=0, use_bias=False,
                                    prefix='pred_')
            self.flatten = nn.Flatten(prefix='flat_')

    def hybrid_forward(self, F, x):
        x = self.flatten(self.output(self.features(x)))
        return x
+
+
def get_mobilenetv2(w, pretrained=False, ctx=cpu(),
                    root=os.path.join('~', '.mxnet', 'models'), **kwargs):
    r"""MobileNetV2 model from the
    `"Inverted Residuals and Linear Bottlenecks: Mobile Networks for
    Classification, Detection and Segmentation"
    <https://arxiv.org/abs/1801.04381>`_ paper.

    Parameters
    ----------
    w : float
        The width multiplier for controlling the model size. The actual
        number of channels is equal to the original channel size multiplied
        by this multiplier.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    """
    net = MobileNetV2(w=w, **kwargs)
    if pretrained:
        from ..model_store import get_model_file
        # Build the model-file suffix from the multiplier, e.g. 0.25 ->
        # 'mobilenetv2_0.25'; drop the redundant trailing zero for the
        # canonical multipliers 1.0 and 0.5.
        suffix = '{0:.2f}'.format(w)
        if suffix in ('1.00', '0.50'):
            suffix = suffix[:-1]
        net.load_params(get_model_file('mobilenetv2_%s' % suffix, root=root),
                        ctx=ctx)
    return net
+
+
+def mobilenetv2_1_0(**kwargs):
Review comment:
provide the same variants as the v1 mobilenet
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services