This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c728d5  [RELAY][MXNET][FRONTEND] add support for MXNET numpy operators (#6054)
4c728d5 is described below

commit 4c728d532b987ae4de464597bbfccc687913ad2d
Author: sandyhu533 <[email protected]>
AuthorDate: Sat Aug 22 04:47:47 2020 +0800

    [RELAY][MXNET][FRONTEND] add support for MXNET numpy operators (#6054)
    
    * [RELAY][MXNET][FRONTEND] add supports for OPs in numpy from mxnet
    
    * Update test_forward.py
    
    * Update mxnet.py
    
    * Update mxnet.py
    
    * Update test_forward.py
    
    * update and bugfix
    
    * test for multiple dtypes
    
    * Update test_forward.py
    
    * add data type and optimize coding style
    
    * replace pytest.skip with @pytest.mark.skipif
    
    * Update test_forward.py
    
    * update pytest style
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    Co-authored-by: Ubuntu <[email protected]>
---
 python/tvm/relay/frontend/mxnet.py          | 105 ++++++++++
 tests/python/frontend/mxnet/test_forward.py | 290 ++++++++++++++++++++--------
 2 files changed, 318 insertions(+), 77 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 36ffe54..1b49c1c 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -2166,6 +2166,93 @@ def _mx_broadcast_logical(logical_op):
     return impl
 
 
+def _mx_npi_transpose(inputs, attrs):
+    axes = attrs.get_int_tuple("axes", None)
+    # translate default case
+    axes = None if len(axes) == 0 or axes[0] is None else axes
+    return _op.transpose(inputs[0], axes=axes)
+
+
+def _mx_npi_pad(inputs, attrs):
+    pad_mode = attrs.get_str('mode', None)
+    if pad_mode is None:
+        raise tvm.error.OpAttributeRequired(
+            'Attribute "mode" not found in operator pad.')
+    if pad_mode not in ['constant', 'edge', 'reflect']:
+        raise tvm.error.OpAttributeInvalid(
+            'Value ' + mode + ' in attribute "mode" is not valid')
+    pad_width = attrs.get_int_tuple('pad_width', None)
+    if pad_width is None:
+        raise tvm.error.OpAttributeRequired(
+            'Attribute "pad_width" not found in operator pad.')
+    if None in pad_width:
+        raise tvm.error.OpAttributeInvalid(
+            'Value None in attribute "pad_width" of operator Slice is not 
valid.')
+    constant_values = attrs.get_float('constant_values', 0.0)
+    padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], 
pad_width[1::2]))
+
+    return _op.nn.pad(data=inputs[0],
+                      pad_width=padding,
+                      pad_value=constant_values,
+                      pad_mode=pad_mode)
+
+
+def _mx_npi_concatenate(inputs, attrs):
+    axis = attrs.get_str("axis", "0")
+    if axis == "None":
+        return _op.reshape(_op.concatenate(tuple(inputs), axis=0), (-1,))
+    else:
+        return _op.concatenate(tuple(inputs), axis=int(axis))
+
+
+def _mx_npx_reshape(inputs, attrs):
+    shape = attrs.get_int_tuple("newshape")
+    reverse = attrs.get_bool("reverse", False)
+    shape_list = list(shape)
+    new_shape_list = []
+    for num in shape_list:
+        if num > 0 or num == -1:
+            new_shape_list.append(num)
+        elif num == -2:
+            new_shape_list.append(0)
+        elif num == -4:
+            new_shape_list.append(-2)
+        elif num == -5:
+            new_shape_list.append(-3)
+        elif num == -6:
+            new_shape_list.append(-4)
+        else:
+            raise tvm.error.OpAttributeInvalid(
+                'Shape dimension %d is not supported' % num)
+    shape = tuple(new_shape_list)
+    if reverse:
+        return _op.reverse_reshape(inputs[0], newshape=shape)
+    return _op.reshape(inputs[0], newshape=shape)
+
+
+def _mx_split_v2(inputs, attrs):
+    axis = attrs.get_int("axis")
+    indices = list(attrs.get_int_tuple("indices", []))
+    # remove the prefix '0'
+    if len(indices) != 0 and indices[0] == 0:
+        indices.remove(0)
+    sections = attrs.get_int("sections", 0)
+    indices_or_sections = list(indices) if len(indices) != 0 else sections
+    res = _op.split(inputs[0], indices_or_sections=indices_or_sections, 
axis=axis)
+    if attrs.get_bool("squeeze_axis", False):
+        res = tuple([_op.squeeze(x, axis=[axis]) for x in res])
+    return res
+
+
+def _mx_npi_where_rscalar(inputs, attrs):
+    scalar = attrs.get_float("scalar")
+    dtype = _infer_type(inputs[1]).checked_type.dtype
+    scalar = _expr.const(scalar, dtype=dtype)
+    ones = _op.ones_like(inputs[1])
+    scalar = _op.multiply(ones, scalar)
+    return _op.where(inputs[0], inputs[1], scalar)
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
@@ -2322,6 +2409,7 @@ _convert_map = {
     "slice_axis"    : _mx_slice_axis,
     "SliceChannel"  : _mx_split,
     "split"         : _mx_split,
+    "_split_v2"     : _mx_split_v2,
     "SwapAxis"      : _mx_swap_axis,
     "expand_dims"   : _mx_expand_dims,
     "Concat"        : _mx_concat,
@@ -2400,6 +2488,23 @@ _convert_map = {
     "_contrib_quantized_pooling": _qnn_pooling,
     "_contrib_quantized_batch_norm" : _qnn_batch_norm,
     "_sg_mkldnn_fully_connected": _qnn_fully_connected,
+    # numpy
+    "_np_transpose"     : _mx_npi_transpose,
+    "_npi_transpose"    : _mx_npi_transpose,
+    "_npi_pad"          : _mx_npi_pad,
+    "_npi_concatenate"  : _mx_npi_concatenate,
+    "_npx_reshape"      : _mx_npx_reshape,
+    "_np_copy"          : _rename(_op.copy),
+    "_npi_power"              : _rename(_op.power),
+    "_npi_power_scalar"       : _binop_scalar(_op.power),
+    "_npi_multiply"           : _rename(_op.multiply),
+    "_npi_multiply_scalar"    : _binop_scalar(_op.multiply),
+    "_npi_add"                : _rename(_op.add),
+    "_npi_add_scalar"         : _binop_scalar(_op.add),
+    "_npi_where_rscalar"      : _mx_npi_where_rscalar,
+    "_npi_less"               : _rename(_op.less),
+    "_npi_tanh"               : _rename(_op.tanh),
+    "_npi_true_divide_scalar" : _binop_scalar(_op.divide),
 }
 
 # set identity list
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 48ad736..594ffe7 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -27,7 +27,8 @@ import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon.model_zoo import vision
 import model_zoo
-
+import random
+import pytest
 
 def verify_mxnet_frontend_impl(mx_symbol,
                                data_shape=(1, 3, 224, 224),
@@ -1410,80 +1411,215 @@ def test_forward_softmax():
     verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 
1]]).astype('int32'))
 
 
[email protected](not hasattr(mx.sym.np, 'pad'), reason="mx.sym.np.pad 
hasn't been publish yet")
[email protected](
+    "data_shape, pad_width",
+    [((1,1,3,5),(0,0,0,0,1,2,3,4)), ((1,1,3,5,7),(0,0,0,0,1,2,3,4,5,6))]
+)
[email protected]("mode", ["constant", "edge", "reflect"])
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32'])
[email protected]("constant_value", [0.0, 3.0])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
+def test_forward_npi_pad(data_shape, pad_width, mode, dtype, 
constant_value,target, ctx, kind):
+    data_np = np.random.uniform(size=data_shape).astype(dtype)
+    data = mx.sym.var('data')
+    if mode == 'constant':
+        ref_res = mx.ndarray.pad(mx.nd.array(data_np), 
mode=mode,pad_width=pad_width, constant_value=constant_value)
+        mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, 
pad_width=pad_width, constant_values=constant_value)
+    else:
+        ref_res = mx.ndarray.pad(mx.nd.array(data_np), 
mode=mode,pad_width=pad_width)
+        mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, 
pad_width=pad_width)
+    mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, 
dtype=dtype)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+    
[email protected](not hasattr(mx.sym.np, 'pad'), reason="test'll abort with 
Mxnet 1.x, skip for now")
[email protected]("data_shape", [(2,2,2),(2,7,2)])
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32', 
'bool'])
[email protected]("axes", [(1,0,2),None])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
+def test_forward_npi_transpose(data_shape, axes, dtype,target, ctx, kind):
+    data_np = np.random.uniform(size=data_shape).astype(dtype)
+    data = mx.sym.var('data')
+    ref_res = mx.np.transpose(mx.np.array(data_np), axes=axes)
+    mx_sym = mx.sym.np.transpose(data.as_np_ndarray(), axes=axes)
+    mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, 
dtype=dtype)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+
[email protected](
+    "data_shape1, data_shape2, axis",
+    
[((2,2),(2,2),1),((2,4),(2,3),1),((1,3,2),(1,3,5),2),((1,3,3),(1,3,3),1),((1,3),(1,3),0)]
+)
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32'])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
+def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype,target, 
ctx, kind):
+    data_np1 = np.random.uniform(size=data_shape1).astype(dtype)
+    data_np2 = np.random.uniform(size=data_shape2).astype(dtype)
+    data1 = mx.sym.var('data1')
+    data2 = mx.sym.var('data2')
+    ref_res = mx.np.concatenate([mx.np.array(data_np1), 
mx.np.array(data_np2)], axis=axis)
+    mx_sym = mx.sym.np.concatenate([data1.as_np_ndarray(), 
data2.as_np_ndarray()], axis=axis)
+    mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data1": data_shape1, 
"data2": data_shape2}, dtype=dtype)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np1, data_np2)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+
[email protected]("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8)])
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32', 
'bool'])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
+def test_forward_np_copy(data_shape,dtype,target, ctx, kind):
+    data_np = np.random.uniform(size=data_shape).astype(dtype)
+    data = mx.sym.var('data')
+    ref_res = mx.np.copy(mx.np.array(data_np))
+    mx_sym = mx.sym.np.copy(data.as_np_ndarray())
+    mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, 
dtype=dtype)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32', 
'bool'])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
[email protected]("data_shape,out_shape,reverse",
+                         [((2, 3, 8),(-2, -2, 2, -1),False),
+                          ((8, 3, 3, 3, 4, 4),(-6, 2, -1, -4),False),
+                          ((8, 3, 3, 3, 4, 4),(-5, -4),False),
+                          ((8, 3, 3, 3, 3, 8),(-4, -5),True),
+                          ((8, 3, 2, 4, 8),(-4, -1, 2, -6),True)])
+def test_forward_npx_reshape(data_shape,out_shape,dtype,target,reverse, ctx, 
kind):
+    data_np = np.random.uniform(size=data_shape).astype(dtype)
+    data = mx.sym.var('data')
+    ref_res = mx.npx.reshape(mx.np.array(data_np), newshape=out_shape, 
reverse=reverse)
+    mx_sym = mx.sym.npx.reshape(data.as_np_ndarray(), newshape=out_shape, 
reverse=reverse)
+    mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, 
dtype=dtype)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+
[email protected]("data_shape", 
[(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)])
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32'])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
+def test_forward_npi_binary(data_shape,dtype,target, ctx, kind):
+    ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.less]
+    mx_ops = [mx.sym.np.power, mx.sym.np.multiply, mx.sym.np.add, 
mx.sym.np.less]
+    for i in range(len(ref_ops)):
+        ref_op = ref_ops[i]
+        mx_op = mx_ops[i]
+        # mx.np.power only support float type
+        if ref_op == mx.np.power and dtype not in ['float64', 'float32']:
+            continue
+        data_np1 = np.random.uniform(size=data_shape).astype(dtype)
+        data_np2 = np.random.uniform(size=data_shape).astype(dtype)
+        data1 = mx.sym.var('lhs')
+        data2 = mx.sym.var('rhs')
+        ref_res = ref_op(mx.np.array(data_np1), mx.np.array(data_np2))
+        mx_sym = mx_op(data1.as_np_ndarray(), data2.as_np_ndarray())
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"lhs": data_shape, 
"rhs": data_shape}, dtype=dtype)
+        intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+        op_res = intrp.evaluate()(data_np1, data_np2)
+        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), 
rtol=1e-5)
+
+
[email protected]("data_shape", 
[(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)])
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32'])
[email protected]("target, ctx", ctx_list())
[email protected]("scalar", [1.0,2.0,3.0,4.0])
[email protected]("kind", ["graph", "vm", "debug"])
+def test_forward_npi_binary_scalar(data_shape,dtype,scalar,target, ctx, kind):
+    ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.true_divide]
+    mx_ops = [mx.sym.np.power, mx.sym.np.multiply, mx.sym.np.add, 
mx.sym.np.true_divide]
+    for i in range(len(ref_ops)):
+        ref_op = ref_ops[i]
+        mx_op = mx_ops[i]
+        # mx.np.power only support float type
+        if ref_op == mx.np.power and dtype not in ['float64', 'float32']:
+            continue
+        data_np1 = np.random.uniform(size=data_shape).astype(dtype)
+        data1 = mx.sym.var('lhs')
+        ref_res = ref_op(mx.np.array(data_np1), scalar)
+        mx_sym = mx_op(data1.as_np_ndarray(), scalar)
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"lhs": data_shape}, 
dtype=dtype)
+        intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+        op_res = intrp.evaluate()(data_np1)
+        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), 
rtol=1e-5)
+
+
[email protected]("data_shape", 
[(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)])
[email protected]("dtype", ['float64', 'float32'])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
+def test_forward_npi_tanh(data_shape,dtype,target, ctx, kind):
+    data_np1 = np.random.uniform(size=data_shape).astype(dtype)
+    data1 = mx.sym.var('data')
+    ref_res = mx.np.tanh(mx.np.array(data_np1))
+    mx_sym = mx.sym.np.tanh(data1.as_np_ndarray())
+    mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, 
dtype=dtype)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np1)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+
[email protected](not hasattr(mx.np, 'where'), reason="mx.np.where hasn't 
been publish yet")
[email protected]("data_shape", [(2,2,2),(2,7,2),(1,8),(2,2),(1,3)])
[email protected]("data_dtype", ['float64', 'float32', 'int64', 
'int32', 'bool'])
[email protected]("cond_dtype", ['float64', 'float32', 'int64', 
'int32', 'bool'])
[email protected]("scalar", [1.0,2.0])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
+def 
test_forward_npi_where_rscalar(data_shape,cond_dtype,data_dtype,scalar,target, 
ctx, kind):
+    if data_dtype == 'bool':
+        scalar = scalar == 0.0
+    cond_np = np.random.uniform(size=data_shape).astype(cond_dtype)
+    data_np = np.random.uniform(size=data_shape).astype(data_dtype)
+    cond = mx.sym.var('condition')
+    data = mx.sym.var('x')
+    ref_res = mx.np.where(mx.np.array(cond_np), mx.np.array(data_np), scalar)
+    mx_sym = mx.sym.np.where(cond.as_np_ndarray(), data.as_np_ndarray(), 
scalar)
+    dtypeDic = {}
+    dtypeDic["condition"] = cond_dtype
+    dtypeDic["x"] = data_dtype
+    mod, _ = relay.frontend.from_mxnet(
+        mx_sym, shape={"condition": data_shape, "x": data_shape}, 
+        dtype=dtypeDic)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(cond_np, data_np)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+
[email protected]("dtype", ['float64', 'float32', 'int64', 'int32', 
'bool'])
[email protected]("target, ctx", ctx_list())
[email protected]("kind", ["graph", "vm", "debug"])
[email protected]("data_shape, axis, indices_or_sections, 
squeeze_axis", 
+                         
[((3,2,1),1,2,False),((3,2,1),0,3,False),((3,2,1),0,3,True),((3,2,1),0,(1,2),False)])
+def test_forward_split_v2(data_shape, axis, dtype, indices_or_sections, 
squeeze_axis, target, ctx, kind):
+    data_np = np.random.uniform(size=data_shape).astype(dtype)
+    data = mx.sym.var('data')
+    ref_res = mx.ndarray.split_v2(mx.nd.array(data_np), indices_or_sections, 
axis=axis, squeeze_axis=squeeze_axis)
+    mx_sym = mx.sym.split_v2(data.as_nd_ndarray(), indices_or_sections, 
axis=axis, squeeze_axis=squeeze_axis)
+    mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, 
dtype=dtype)
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np)
+    op_res_ = []
+    for arr in op_res:
+        op_res_.append(arr.asnumpy().tolist())
+    ref_res_ = []
+    for arr in  ref_res:
+        ref_res_.append(arr.asnumpy().tolist())
+    tvm.testing.assert_allclose(op_res_, ref_res_, rtol=1e-5)
+
+
 if __name__ == '__main__':
-    test_forward_mlp()
-    test_forward_vgg()
-    test_forward_resnet()
-    test_forward_leaky_relu()
-    test_forward_elu()
-    test_forward_rrelu()
-    test_forward_prelu()
-    test_forward_gelu()
-    test_forward_softrelu()
-    test_forward_softmin()
-    test_forward_fc_flatten()
-    test_forward_clip()
-    test_forward_split()
-    test_forward_split_squeeze()
-    test_forward_expand_dims()
-    test_forward_pad()
-    test_forward_slice()
-    test_forward_pooling()
-    test_forward_pooling3d()
-    test_forward_adaptive_pooling()
-    test_forward_lrn()
-    test_forward_ones()
-    test_forward_zeros()
-    test_forward_ones_like()
-    test_forward_zeros_like()
-    test_forward_argmax()
-    test_forward_argmin()
-    test_forward_where()
-    test_forward_arange()
-    test_forward_broadcast_ops()
-    test_forward_broadcast_to()
-    test_forward_logical_not()
-    test_forward_elemwise_ops()
-    test_forward_unary_ops()
-    test_forward_scalar_ops()
-    test_forward_slice_like()
-    test_forward_slice_axis()
-    test_forward_sequence_reverse()
-    test_forward_l2_normalize()
-    test_forward_shape_array()
-    test_forward_squeeze()
-    test_forward_broadcast_axis()
-    test_forward_full()
-    test_forward_embedding()
-    test_forward_smooth_l1()
-    test_forward_take()
-    test_forward_gather_nd()
-    test_forward_bilinear_resize()
-    test_forward_rnn_layer()
-    test_forward_Crop()
-    test_forward_argsort()
-    test_forward_topk()
-    test_forward_sequence_mask()
-    test_forward_contrib_div_sqrt_dim()
-    test_forward_batch_norm()
-    test_forward_instance_norm()
-    test_forward_layer_norm()
-    test_forward_one_hot()
-    test_forward_depth_to_space()
-    test_forward_space_to_depth()
-    test_forward_convolution()
-    test_forward_deconvolution()
-    test_forward_cond()
-    test_forward_make_loss()
-    test_forward_unravel_index()
-    test_forward_swap_axis()
-    test_forward_correlation()
-    test_forward_grid_generator()
-    test_forward_bilinear_sampler()
-    test_forward_arange_like()
-    test_forward_interleaved_matmul_selfatt_qk()
-    test_forward_interleaved_matmul_selfatt_valatt()
-    test_forward_box_decode()
-    test_forward_amp_multicast()
-    test_forward_amp_cast()
-    test_forward_softmax()
+    pytest.main(['test_forward.py'])

Reply via email to