sxjscience commented on a change in pull request #16229: Pseudo 2D transpose kernel
URL: https://github.com/apache/incubator-mxnet/pull/16229#discussion_r336165403
##########
File path: tests/python/unittest/test_operator.py
##########
@@ -2876,6 +2876,45 @@ def test_transpose():
@with_seed()
+def test_pseudo2dtranspose():
+ def getTwoInts(mn, mx):
+ n1 = np.random.randint(mn, mx)
+ n2 = np.random.randint(mn, mx-1)
+ n2 = n2 if n2 < n1 else n2+1
+ return tuple(np.sort([n1, n2]))
+
+ def getTranspAxes(ndim):
+ axes = list(range(ndim))
+ n1, n2 = getTwoInts(0,ndim)
+ return tuple(axes[:n1]+axes[n2:]+axes[n1:n2])
+
+ for ndim in range(2, 7):
+ for dt in ['int8', 'half', 'int32', 'int64']:
+ for _ in range(5):
+ dims = list(np.random.randint(5, 20, size=ndim))
+ axes = getTranspAxes(ndim)
+ x = mx.nd.array(np.random.normal(size=dims), dtype=dt)
+ y = mx.nd.transpose(x, axes=axes)
+ assert_allclose(np.transpose(x.asnumpy(), axes=axes),
y.asnumpy())
+
+
+@with_seed()
+def test_big_transpose():
+ n = [1]
+ d = list(np.random.randint(132, 160, size=1))
+ hw = list(np.random.randint(256, 320, size=2))
+ c = [10]
+ dims = n + d + hw + c
+ axes = (0,4,1,2,3)
+ x_np = np.random.normal(size=dims).astype('uint8')
+ x = mx.nd.array(x_np, dtype='uint8')
+ y = mx.nd.transpose(x, axes=axes)
+ assert_allclose(np.transpose(x_np, axes=axes), y.asnumpy().astype('uint8'))
+ axes = (0,2,3,4,1)
+ z = mx.nd.transpose(y, axes=axes)
+ assert_allclose(x_np, z.asnumpy().astype('uint8'))
+
+
Review comment:
@ChaiBapchya Sorry, my bad. I should have caught this. I need to be more
careful next time.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services