kohr-h commented on a change in pull request #13571: Add pixelshuffle layers
URL: https://github.com/apache/incubator-mxnet/pull/13571#discussion_r240000688
##########
File path: python/mxnet/gluon/contrib/nn/basic_layers.py
##########
@@ -235,3 +235,135 @@ def _get_num_devices(self):
def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean,
running_var,
name='fwd', **self._kwargs)
+
class PixelShuffle1D(HybridBlock):

    """Pixel-shuffle layer for upsampling in 1 dimension.

    Rearranges elements from the channel dimension into the spatial
    ``W`` dimension: ``(N, f*C, W) -> (N, C, f*W)``.
    """

    def __init__(self, factor):
        """
        Parameters
        ----------
        factor : int or 1-tuple of int
            Upsampling factor, applied to the ``W`` dimension.

        Inputs:
            - **data**: Tensor of shape ``(N, f*C, W)``.
        Outputs:
            - **out**: Tensor of shape ``(N, C, f*W)``.

        Examples
        --------
        >>> pxshuf = PixelShuffle1D(2)
        >>> x = mx.nd.zeros((1, 8, 3))
        >>> pxshuf(x).shape
        (1, 4, 6)
        """
        super().__init__()
        self._factor = int(factor)

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        f = self._factor
        # Split the upsampling factor out of the channel dimension, then
        # interleave it with W.  The transpose is essential: merging the
        # (f, W) axes in their original order would stack whole f-sized
        # blocks (index f*W + w) instead of interleaving pixels (w*f + i).
                                              # (N, f*C, W)
        x = F.reshape(x, (0, -4, -1, f, 0))   # (N, C, f, W)
        x = F.transpose(x, (0, 1, 3, 2))      # (N, C, W, f)
        x = F.reshape(x, (0, 0, -3))          # (N, C, W*f)
        return x

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self._factor)
+
+
class PixelShuffle2D(HybridBlock):

    """Pixel-shuffle layer for upsampling in 2 dimensions.

    Rearranges elements from the channel dimension into the spatial
    ``H`` and ``W`` dimensions: ``(N, f1*f2*C, H, W) -> (N, C, f1*H, f2*W)``.
    """

    def __init__(self, factor):
        """
        Parameters
        ----------
        factor : int or 2-tuple of int
            Upsampling factors, applied to the ``H`` and ``W`` dimensions,
            in that order.

        Inputs:
            - **data**: Tensor of shape ``(N, f1*f2*C, H, W)``.
        Outputs:
            - **out**: Tensor of shape ``(N, C, f1*H, f2*W)``.

        Examples
        --------
        >>> pxshuf = PixelShuffle2D((2, 3))
        >>> x = mx.nd.zeros((1, 12, 3, 5))
        >>> pxshuf(x).shape
        (1, 2, 6, 15)
        """
        super().__init__()
        try:
            # A single int applies to both spatial dimensions.
            self._factors = (int(factor),) * 2
        except TypeError:
            self._factors = tuple(int(fac) for fac in factor)
            assert len(self._factors) == 2, \
                "wrong length {}".format(len(self._factors))

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        f1, f2 = self._factors
        # Split both factors out of the channel dimension, then use a
        # transpose to interleave f1 with H and f2 with W.  Without the
        # transpose, reshaping would merge the factors as major indices,
        # stacking whole blocks instead of interleaving pixels.
                                                       # (N, f1*f2*C, H, W)
        x = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0))   # (N, C, f1*f2, H, W)
        x = F.reshape(x, (0, 0, -4, f1, f2, 0, 0))     # (N, C, f1, f2, H, W)
        x = F.transpose(x, (0, 1, 4, 2, 5, 3))         # (N, C, H, f1, W, f2)
        x = F.reshape(x, (0, 0, -3, -3))               # (N, C, H*f1, W*f2)
        return x

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self._factors)
+
+
+class PixelShuffle3D(HybridBlock):
+
+ """Pixel-shuffle layer for upsampling in 3 dimensions."""
+
+ def __init__(self, factor):
+ """
+
+ Parameters
+ ----------
+ factor : int or 3-tuple of int
+ Upsampling factors, applied to the ``D``, ``H`` and ``W``
+ dimensions, in that order.
+
+ Inputs:
+ - **data**: Tensor of shape ``(N, f1*f2*f3*C, D, H, W)``.
+ Outputs:
+ - **out**: Tensor of shape ``(N, C, f1*D,f2*H, f3*W)``.
+
+ Examples
+ --------
+ >>> pxshuf = PixelShuffle3D((2, 3, 4))
+ >>> x = mx.nd.zeros((1, 48, 3, 5, 7))
+ >>> pxshuf(x).shape
+ (1, 2, 6, 15, 28)
+ """
+ super().__init__()
+ try:
+ self._factors = (int(factor),) * 3
+ except TypeError:
+ self._factors = tuple(int(fac) for fac in factor)
+ assert len(self._factors) == 3, "wrong length
{}".format(len(self._factors))
+
+ def hybrid_forward(self, F, x):
+ """Perform pixel-shuffling on the input."""
+ f1, f2, f3 = self._factors
+ # (N,
f1*f2*f3*C, D, H, W)
+ x = F.reshape(x, (0, -4, -1, f1 * f2 * f3, 0, 0, 0)) # (N, C,
f1*f2*f3, D, H, W)
Review comment:
True, but if the answer is "this will never work on GPU", then it's quicker
to ask than to write tests and wait for CI to run them.
But sure, I'll add some tests.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services