This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9fdb86d3f6 [Relax][ONNX] Expand op support for ONNX frontend (#17427)
9fdb86d3f6 is described below

commit 9fdb86d3f6bccc41a772328b5b0442908bc9f9a9
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Oct 3 22:36:55 2024 +0800

    [Relax][ONNX] Expand op support for ONNX frontend (#17427)
    
    * [Relax][ONNX] Expand op support for ONNX frontend
    
    This PR adds a variety of ONNX ops to the Relax frontend, including:
    
    - Acos
    - Acosh
    - And
    - Asin
    - Asinh
    - Atan
    - Atanh
    - BitwiseAnd
    - BitwiseOr
    - BitwiseXor
    - Ceil
    - ConcatFromSequence
    - ConvTranspose
    - Cosh
    - DepthToSpace
    - FastGelu
    - Floor
    - GlobalLpPool
    - GlobalMaxPool
    - GreaterOrEqual
    - IsInf
    - IsNaN
    - LeakyRelu
    - LogSoftmax
    - MaxUnpool
    - Mean
    - MeanVarianceNormalization
    - Mish
    - Or
    - PRelu
    - Round
    - Scatter
    - ScatterElements
    - Selu
    - SequenceAt
    - SequenceConstruct
    - SequenceEmpty
    - SequenceErase
    - SequenceInsert
    - SequenceLength
    - Shrink
    - Sinh
    - Size
    - Softplus
    - Softsign
    - SpaceToDepth
    - SplitToSequence
    - Tan
    - ThresholdedRelu
    - TopK
    - Unique
    - Xor
    
    A few ops are still unsupported; see the commented-out ops in the
    ONNX frontend.
    
    * lint
    
    * lint
    
    * lint
    
    * update for ci
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py    | 1302 ++++++++++++++++----
 python/tvm/relax/op/set.py                         |    8 +-
 python/tvm/relax/transform/legalize_ops/nn.py      |    9 +-
 tests/python/relax/test_frontend_onnx.py           |  664 ++++++++--
 tests/python/relax/test_relax_operators.py         |    2 +-
 .../python/relax/test_transform_legalize_ops_nn.py |   47 +
 6 files changed, 1617 insertions(+), 415 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 462d1cf92c..5777f51fe2 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -34,14 +34,15 @@ If this fails, there may still be dynamic operations in the model.
 Not all TVM kernels currently support dynamic shapes, please file an issue on
 github.com/apache/tvm/issues if you hit an error with dynamic kernels.
 """
+import math
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as _np
 import onnx.onnx_ml_pb2
 
 import tvm
-from tvm import relax, tir, topi
+from tvm import TVMError, relax, tir, topi
 from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
 from tvm.tir.generic import cast
@@ -236,28 +237,181 @@ class MatMul(OnnxOpConverter):
         return relax.op.matmul(inputs[0], inputs[1])
 
 
-class Div(OnnxOpConverter):
-    """Converts an onnx Div node into an equivalent Relax expression."""
+class BinaryBase(OnnxOpConverter):
+    """Base converter for elementwise binary ONNX ops; subclasses set numpy_op and relax_op."""
+
+    numpy_op: Callable = None
+    relax_op: Callable = None
 
     @classmethod
-    def _impl_v14(cls, bb, inputs, attr, params):
+    def _impl_v1(cls, bb, inputs, attr, params):
+        if cls.numpy_op is None or cls.relax_op is None:
+            raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
         if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = inputs[0].data.numpy() / inputs[1].data.numpy()
+            output = cls.numpy_op(  # pylint: disable=not-callable
+                inputs[0].data.numpy(), inputs[1].data.numpy()
+            )
-            return relax.const(output, inputs[0].struct_info.dtype)
+            # Comparison/logical ops produce bool; only arithmetic results keep the input dtype.
+            dtype = "bool" if output.dtype == _np.bool_ else inputs[0].struct_info.dtype
+            return relax.const(output, dtype)
         if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
             x = (
-                int(inputs[0].value)
+                _np.array(inputs[0].value)
                 if isinstance(inputs[0], relax.PrimValue)
                 else inputs[0].data.numpy()
             )
             y = (
-                int(inputs[1].value)
+                _np.array(inputs[1].value)
                 if isinstance(inputs[1], relax.PrimValue)
                 else inputs[1].data.numpy()
             )
-            return relax.PrimValue(int(x / y))
+            return relax.PrimValue(cls.numpy_op(x, y))  # pylint: disable=not-callable
+
+        return cls.relax_op(inputs[0], inputs[1])  # pylint: disable=not-callable
+
+
+class Add(BinaryBase):
+    """Converts an onnx Add node into an equivalent Relax expression."""
+
+    numpy_op = _np.add
+    relax_op = relax.op.add
+
+
+class Sub(BinaryBase):
+    """Converts an onnx Sub node into an equivalent Relax expression."""
+
+    numpy_op = _np.subtract
+    relax_op = relax.op.subtract
+
+
+class Mul(BinaryBase):
+    """Converts an onnx Mul node into an equivalent Relax expression."""
+
+    numpy_op = _np.multiply
+    relax_op = relax.op.multiply
+
+
+class Div(BinaryBase):
+    """Converts an onnx Div node into an equivalent Relax expression."""
+
+    numpy_op = _np.divide
+    relax_op = relax.op.divide
+
+
+class Pow(BinaryBase):
+    """Converts an onnx Pow node into an equivalent Relax expression."""
+
+    numpy_op = _np.power
+    relax_op = relax.op.power
+
+
+class And(BinaryBase):
+    """Converts an onnx And node into an equivalent Relax expression."""
+
+    numpy_op = _np.logical_and
+    relax_op = relax.op.logical_and
 
-        return relax.op.divide(inputs[0], inputs[1])
+
+class Or(BinaryBase):
+    """Converts an onnx Or node into an equivalent Relax expression."""
+
+    numpy_op = _np.logical_or
+    relax_op = relax.op.logical_or
+
+
+class Xor(BinaryBase):
+    """Converts an onnx Xor node into an equivalent Relax expression."""
+
+    numpy_op = _np.logical_xor
+    relax_op = relax.op.logical_xor
+
+
+class Less(BinaryBase):
+    """Converts an onnx Less node into an equivalent Relax expression."""
+
+    numpy_op = _np.less
+    relax_op = relax.op.less
+
+
+class LessOrEqual(BinaryBase):
+    """Converts an onnx LessOrEqual node into an equivalent Relax expression."""
+
+    numpy_op = _np.less_equal
+    relax_op = relax.op.less_equal
+
+
+class Greater(BinaryBase):
+    """Converts an onnx Greater node into an equivalent Relax expression."""
+
+    numpy_op = _np.greater
+    relax_op = relax.op.greater
+
+
+class GreaterOrEqual(BinaryBase):
+    """Converts an onnx GreaterOrEqual node into an equivalent Relax expression."""
+
+    numpy_op = _np.greater_equal
+    relax_op = relax.op.greater_equal
+
+
+class Equal(OnnxOpConverter):
+    """Converts an onnx Equal node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        if all([isinstance(inp, relax.Constant) for inp in inputs]):
+            output = inputs[0].data.numpy() == inputs[1].data.numpy()
+            return relax.const(output, output.dtype)
+        elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp in inputs]):
+            lhs = get_prim_expr_list(inputs[0])
+            rhs = get_prim_expr_list(inputs[1])
+            if len(lhs) != len(rhs):
+                raise ValueError("Cannot compare two tensors with different shapes")
+            output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)]
+            return relax.const(output, "bool")
+        return relax.op.equal(inputs[0], inputs[1])
+
+
+class BitwiseBase(BinaryBase):
+    """Base converter for ONNX bitwise ops; rejects non-integer inputs before dispatching."""
+
+    @classmethod
+    def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
+        valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
+        for num, inp in enumerate(inputs):
+            if inp.struct_info.dtype not in valid_types:
+                raise ValueError(
+                    f"Bitwise operations expect all inputs to have integer types, "
+                    f"got {inp.struct_info.dtype} for input {num}"
+                )
+        # BinaryBase has no base_impl helper, so fold constants / emit relax_op directly here.
+        if all(isinstance(inp, relax.Constant) for inp in inputs):
+            return relax.const(py_func(inputs[0].data.numpy(), inputs[1].data.numpy()))
+        return relax_op(inputs[0], inputs[1])
+
+
+class BitwiseAnd(BitwiseBase):
+    """Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and)
+
+
+class BitwiseOr(BitwiseBase):
+    """Converts an onnx BitwiseOr node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or)
+
+
+class BitwiseXor(BitwiseBase):
+    """Converts an onnx BitwiseXor node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor)
 
 
 class Sigmoid(OnnxOpConverter):
@@ -277,6 +426,15 @@ class Softmax(OnnxOpConverter):
         return relax.op.nn.softmax(inputs[0], axis=axis)
 
 
+class LogSoftmax(OnnxOpConverter):
+    """Converts an onnx LogSoftmax node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", -1)
+        return relax.op.nn.log_softmax(inputs[0], axis=axis)
+
+
 class Transpose(OnnxOpConverter):
     """Converts an onnx Transpose node into an equivalent Relax expression."""
 
@@ -375,67 +533,6 @@ class Concat(OnnxOpConverter):
         return relax.op.concat(inputs, axis=axis)
 
 
-class Add(OnnxOpConverter):
-    """Convert an onnx Add node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = inputs[0].data.numpy() + inputs[1].data.numpy()
-            return relax.const(output, output.dtype)
-        # If primvalues are involved, handle them directly.
-        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
-            x = (
-                int(inputs[0].value)
-                if isinstance(inputs[0], relax.PrimValue)
-                else inputs[0].data.numpy()
-            )
-            y = (
-                int(inputs[1].value)
-                if isinstance(inputs[1], relax.PrimValue)
-                else inputs[1].data.numpy()
-            )
-            return relax.PrimValue(int(x + y))
-        return relax.op.add(inputs[0], inputs[1])
-
-
-class Sum(OnnxOpConverter):
-    """Convert an onnx Sum node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
-        for in_index in range(len(inputs) - 1):
-            inputs[in_index + 1] = relax.op.add(inputs[in_index], 
inputs[in_index + 1])
-
-        return inputs[len(inputs) - 1]
-
-
-class Mul(OnnxOpConverter):
-    """Convert an onnx Mul node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        # When all inputs are constant, directly multiply.
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = inputs[0].data.numpy() * inputs[1].data.numpy()
-            return relax.const(output, output.dtype)
-        # If primvalues are involved, handle them directly.
-        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
-            x = (
-                int(inputs[0].value)
-                if isinstance(inputs[0], relax.PrimValue)
-                else inputs[0].data.numpy()
-            )
-            y = (
-                int(inputs[1].value)
-                if isinstance(inputs[1], relax.PrimValue)
-                else inputs[1].data.numpy()
-            )
-            return relax.PrimValue(int(x * y))
-
-        return relax.op.multiply(inputs[0], inputs[1])
-
-
 class Cast(OnnxOpConverter):
     """Convert an onnx Cast node into an equivalent Relax expression."""
 
@@ -482,8 +579,38 @@ class Gather(OnnxOpConverter):
             shape_val = data[np_index]
             return relax.PrimValue(shape_val)
 
-        # TODO(jwfromm) Make relax.take work with other indices shape.
-        return bb.emit_te(topi.take, data, indices, axis)
+        return relax.op.take(data, indices, axis)
+
+
+class Scatter(OnnxOpConverter):
+    """Convert an onnx Scatter node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], 
axis=axis)
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        raise ValueError("Scatter is deprecated in ONNX 11")
+
+
+class ScatterElements(OnnxOpConverter):
+    """Convert an onnx ScatterElements node into an equivalent Relax 
expression."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], 
axis=axis)
+
+
+class Size(OnnxOpConverter):
+    """Convert an onnx Size node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        # TODO(tvm-team): add native support for size op
+        return 
relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
 
 
 class Gemm(OnnxOpConverter):
@@ -542,29 +669,6 @@ class Reshape(OnnxOpConverter):
         return out
 
 
-class Gelu(OnnxOpConverter):
-    """Operator converter for Gelu from Microsoft onnxruntime contrib opset.
-
-    gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
-    """
-
-    @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
-        return relax.op.nn.gelu(inputs[0])
-
-
-class BiasGelu(OnnxOpConverter):
-    """Operator converter for BiasGelu from Microsoft onnxruntime contrib 
opset.
-
-    bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
-    """
-
-    @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
-        inp = relax.op.add(inputs[0], inputs[1])
-        return relax.op.nn.gelu(inp)
-
-
 class Where(OnnxOpConverter):
     """Convert an onnx Where node into an equivalent Relax expression."""
 
@@ -605,24 +709,6 @@ class Clip(OnnxOpConverter):
         return results
 
 
-class Equal(OnnxOpConverter):
-    """Converts an onnx Equal node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = inputs[0].data.numpy() == inputs[1].data.numpy()
-            return relax.const(output, output.dtype)
-        elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp 
in inputs]):
-            lhs = get_prim_expr_list(inputs[0])
-            rhs = get_prim_expr_list(inputs[1])
-            if len(lhs) != len(rhs):
-                raise ValueError("Cannot compare two tensors with different 
shapes")
-            output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)]
-            return relax.const(output, "bool")
-        return relax.op.equal(inputs[0], inputs[1])
-
-
 class Shape(OnnxOpConverter):
     """Converts an onnx Equal node into an equivalent Relax expression."""
 
@@ -643,22 +729,6 @@ class Shape(OnnxOpConverter):
         return data_info.shape
 
 
-class Tanh(OnnxOpConverter):
-    """Converts an onnx Tanh node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        return relax.op.tanh(inputs[0])
-
-
-class Sqrt(OnnxOpConverter):
-    """Converts an onnx Sqrt node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        return relax.op.sqrt(inputs[0])
-
-
 class Trilu(OnnxOpConverter):
     """Given a 2-D matrix or batches of 2-D matrices, returns the upper or
     lower triangular part of the tensor(s)
@@ -691,12 +761,157 @@ class Relu(OnnxOpConverter):
         return relax.op.nn.relu(inputs[0])
 
 
-class Pow(OnnxOpConverter):
-    """Converts an onnx Pow node into an equivalent Relax expression."""
+class Elu(OnnxOpConverter):
+    """Converts an onnx Elu node into an equivalent Relax expression."""
 
     @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        return relax.op.power(inputs[0], inputs[1])
+    def _impl_v1(cls, bb, inputs, attr, params):
+        dtype = inputs[0].struct_info.dtype
+        neg_alpha = relax.const(-float(attr.get("alpha", 1.0)), dtype)
+        neg_part = neg_alpha * relax.op.nn.relu(relax.const(1.0, dtype) - relax.op.exp(inputs[0]))
+        return neg_part + relax.op.nn.relu(inputs[0])
+
+
+class Selu(OnnxOpConverter):
+    """Converts an onnx Selu node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        dtype = inputs[0].struct_info.dtype
+        neg_alpha = relax.const(-attr.get("alpha", 1.67326319217681884765625), dtype)
+        gamma = relax.const(attr.get("gamma", 1.05070102214813232421875), dtype)
+        # For x <= 0: -alpha * relu(1 - exp(x)) == alpha * (exp(x) - 1); it is 0 for x > 0.
+        neg_part = neg_alpha * relax.op.nn.relu(relax.const(1.0, dtype) - relax.op.exp(inputs[0]))
+        return gamma * (neg_part + relax.op.nn.relu(inputs[0]))
+
+
+class Mish(OnnxOpConverter):
+    """Converts an onnx Mish node into an equivalent Relax expression.
+
+    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
+    """
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        dtype = inputs[0].struct_info.dtype
+        return inputs[0] * relax.op.tanh(
+            relax.op.log(relax.const(1.0, dtype) + relax.op.exp(inputs[0]))
+        )
+
+
+class PRelu(OnnxOpConverter):
+    """Converts an onnx PRelu node into an equivalent Relax expression.
+
+    f(x) = slope * x for x < 0, x for x >= 0
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        slope = inputs[1]
+        # TODO(tvm-team): Should add a new op for this.
+        return x * slope + relax.op.nn.relu(x) * (relax.const(1.0, x.struct_info.dtype) - slope)
+
+
+class ThresholdedRelu(OnnxOpConverter):
+    """Converts an onnx ThresholdedRelu node into an equivalent Relax expression.
+
+    f(x) = x for x > alpha, 0 otherwise
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        alpha = relax.const(attr.get("alpha", 1.0), x.struct_info.dtype)
+        return relax.op.greater(x, alpha).astype(x.struct_info.dtype) * x
+
+
+class LeakyRelu(OnnxOpConverter):
+    """Converts an onnx LeakyRelu node into an equivalent Relax expression.
+
+    f(x) = x for x > 0, alpha * x otherwise
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        alpha = attr.get("alpha", 0.01)
+        return relax.op.nn.leakyrelu(x, alpha)
+
+
+class Gelu(OnnxOpConverter):
+    """Operator converter for Gelu from Microsoft onnxruntime contrib opset.
+
+    gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return relax.op.nn.gelu(inputs[0])
+
+
+class FastGelu(OnnxOpConverter):
+    """Operator converter for FastGelu from Microsoft onnxruntime contrib opset.
+
+    fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
+                 = 0.5x(1 + tanh((sqrt(2/pi)x + 0.044715(sqrt(2/pi)x^3)))
+                 = 0.5x(1 + tanh(c1 * x + c2 * x^3)))
+    , where
+        c1 = sqrt(2/pi)
+        c2 = 0.044715 * sqrt(2/pi)
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        if len(inputs) > 1 and inputs[1] is not None:
+            bias = inputs[1]
+            assert len(bias.struct_info.shape) == 1, "bias term must be a 1D tensor"
+            x = bb.normalize(relax.op.add(x, bias))
+
+        # Declare consts
+        const_dtype = x.struct_info.dtype
+        half = relax.const(0.5, dtype=const_dtype)
+        one = relax.const(1.0, dtype=const_dtype)
+        const1 = relax.const(math.sqrt(2 / math.pi), dtype=const_dtype)
+        const2 = relax.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)
+
+        # Compute FastGelu
+        term1 = relax.op.multiply(half, x)
+        term2 = relax.op.multiply(const1, x)
+        term3 = relax.op.multiply(const2, relax.op.power(x, relax.const(3, const_dtype)))
+        tanh = relax.op.tanh(relax.op.add(term2, term3))
+        return relax.op.multiply(term1, relax.op.add(one, tanh))
+
+
+class BiasGelu(OnnxOpConverter):
+    """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
+
+    bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        inp = relax.op.add(inputs[0], inputs[1])
+        return relax.op.nn.gelu(inp)
+
+
+class Shrink(OnnxOpConverter):
+    """Converts an onnx Shrink node into an equivalent Relax expression.
+
+    f(x) = x - bias if x > lambd, x + bias if x < -lambd, 0 otherwise
+    """
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        dtype = x.struct_info.dtype
+        lambd = relax.const(attr.get("lambd", 0.5), dtype)
+        bias = relax.const(attr.get("bias", 0.0), dtype)
+        zeros = relax.op.zeros_like(x)
+        return relax.op.where(x > lambd, x - bias, zeros) + relax.op.where(
+            x < -lambd, x + bias, zeros
+        )
 
 
 class Conv(OnnxOpConverter):
@@ -730,21 +945,56 @@ class Conv(OnnxOpConverter):
                 weight=inputs[1],
                 strides=attr.get("strides", 1),
                 padding=attr.get("pads", 0),
-                dilation=attr.get("dilation", 1),
+                dilation=attr.get("dilations", 1),
                 groups=attr.get("group", 1),
                 data_layout=data_layout,
                 kernel_layout=kernel_layout,
             )
         )
         if inputs[2] is not None:
-            bias = relax.op.reshape(
-                inputs[2],
-                [1, -1]
-                + [
-                    1,
-                ]
-                * (ndim - 2),
-            )
+            bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2))
+            conv_out = relax.op.add(conv_out, bias)
+
+        return conv_out
+
+
+class ConvTranspose(OnnxOpConverter):
+    """Converts an onnx ConvTranspose node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        if hasattr(inputs[0].struct_info, "ndim"):
+            ndim = inputs[0].struct_info.ndim
+        else:
+            ndim = len(inputs[0].struct_info.shape)
+
+        if ndim == 3:
+            op = relax.op.nn.conv1d_transpose
+            data_layout = "NCW"
+            kernel_layout = "IOW"
+        elif ndim == 4:
+            op = relax.op.nn.conv2d_transpose
+            data_layout = "NCHW"
+            kernel_layout = "IOHW"
+        elif ndim == 5:
+            raise NotImplementedError("Relax ConvTranspose3d not supported yet")
+        else:
+            raise NotImplementedError("Ndim > 5 not supported for convolution.")
+
+        conv_out = op(
+            data=inputs[0],
+            weight=inputs[1],
+            strides=attr.get("strides", 1),
+            padding=attr.get("pads", 0),
+            output_padding=attr.get("output_padding", 0),
+            dilation=attr.get("dilations", 1),
+            groups=attr.get("group", 1),
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+
+        if inputs[2] is not None:
+            bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2))
             conv_out = relax.op.add(conv_out, bias)
 
         return conv_out
@@ -839,17 +1088,6 @@ class ConstantOfShape(OnnxOpConverter):
         return relax.op.broadcast_to(relax.const(value, dtype), shape)
 
 
-class Sub(OnnxOpConverter):
-    """Converts an onnx Sub node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = inputs[0].data.numpy() - inputs[1].data.numpy()
-            return relax.const(output, output.dtype)
-        return relax.op.subtract(inputs[0], inputs[1])
-
-
 class Sin(OnnxOpConverter):
     """Converts an onnx Sin node into an equivalent Relax expression."""
 
@@ -858,6 +1096,14 @@ class Sin(OnnxOpConverter):
         return relax.op.sin(inputs[0])
 
 
+class Sinh(OnnxOpConverter):
+    """Converts an onnx Sinh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        return relax.op.sinh(inputs[0])
+
+
 class Cos(OnnxOpConverter):
     """Converts an onnx Cos node into an equivalent Relax expression."""
 
@@ -866,6 +1112,78 @@ class Cos(OnnxOpConverter):
         return relax.op.cos(inputs[0])
 
 
+class Cosh(OnnxOpConverter):
+    """Converts an onnx Cosh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        return relax.op.cosh(inputs[0])
+
+
+class Tan(OnnxOpConverter):
+    """Converts an onnx Tan node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v7(cls, bb, inputs, attr, params):
+        return relax.op.tan(inputs[0])
+
+
+class Tanh(OnnxOpConverter):
+    """Converts an onnx Tanh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v7(cls, bb, inputs, attr, params):
+        return relax.op.tanh(inputs[0])
+
+
+class Acos(OnnxOpConverter):
+    """Converts an onnx Acos node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v7(cls, bb, inputs, attr, params):
+        return relax.op.acos(inputs[0])
+
+
+class Acosh(OnnxOpConverter):
+    """Converts an onnx Acosh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        return relax.op.acosh(inputs[0])
+
+
+class Asin(OnnxOpConverter):
+    """Converts an onnx Asin node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v7(cls, bb, inputs, attr, params):
+        return relax.op.asin(inputs[0])
+
+
+class Asinh(OnnxOpConverter):
+    """Converts an onnx Asinh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        return relax.op.asinh(inputs[0])
+
+
+class Atan(OnnxOpConverter):
+    """Converts an onnx Atan node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v7(cls, bb, inputs, attr, params):
+        return relax.op.atan(inputs[0])
+
+
+class Atanh(OnnxOpConverter):
+    """Converts an onnx Atanh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        return relax.op.atanh(inputs[0])
+
+
 class Neg(OnnxOpConverter):
     """Converts an onnx Neg node into an equivalent Relax expression."""
 
@@ -877,47 +1195,124 @@ class Neg(OnnxOpConverter):
         return relax.op.negative(inputs[0])
 
 
-class Abs(OnnxOpConverter):
-    """Converts an onnx Abs node into an equivalent Relax expression."""
+class Abs(OnnxOpConverter):
+    """Converts an onnx Abs node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        if isinstance(inputs[0], relax.Constant):
+            output = _np.abs(inputs[0].data.numpy())
+            return relax.const(output, output.dtype)
+        return relax.op.abs(inputs[0])
+
+
+class Reciprocal(OnnxOpConverter):
+    """Converts an onnx Reciprocal node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        input_dtype = inputs[0].struct_info.dtype
+        return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0])
+
+
+class Floor(OnnxOpConverter):
+    """Converts an onnx Floor node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return relax.op.floor(inputs[0])
+
+
+class Ceil(OnnxOpConverter):
+    """Converts an onnx Ceil node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return relax.op.ceil(inputs[0])
+
+
+class Round(OnnxOpConverter):
+    """Converts an onnx Round node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return relax.op.round(inputs[0])
+
+
+class IsInf(OnnxOpConverter):
+    """Converts an onnx IsInf node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        return relax.op.isinf(inputs[0])
+
+
+class IsNaN(OnnxOpConverter):
+    """Converts an onnx IsNaN node into an equivalent Relax expression."""
 
     @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if isinstance(inputs[0], relax.Constant):
-            output = _np.abs(inputs[0].data.numpy())
-            return relax.const(output, output.dtype)
-        return relax.op.abs(inputs[0])
+    def _impl_v9(cls, bb, inputs, attr, params):
+        return relax.op.isnan(inputs[0])
 
 
-class Min(OnnxOpConverter):
-    """Converts an onnx Min node into an equivalent Relax expression."""
+class Sqrt(OnnxOpConverter):
+    """Converts an onnx Sqrt node into an equivalent Relax expression."""
 
     @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return relax.op.sqrt(inputs[0])
+
+
+class MultiInputBase(OnnxOpConverter):
+    """Base converter for variadic ONNX ops (Min/Max/Mean/Sum) reduced over a stacked axis."""
+
+    numpy_op: Callable = None
+    relax_op: Callable = None
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        if cls.numpy_op is None or cls.relax_op is None:
+            raise NotImplementedError("numpy_op and relax_op must be defined for MultiInputBase")
         if all([isinstance(inp, relax.Constant) for inp in inputs]):
             np_inputs = [inp.data.numpy() for inp in inputs]
-            output = _np.minimum(*np_inputs)
+            # numpy_op is a reduction (_np.min, _np.sum, ...) whose 2nd positional arg is `axis`,
+            # so broadcast + stack the inputs and reduce over the new leading axis instead.
+            stacked = _np.stack(_np.broadcast_arrays(*np_inputs))
+            output = cls.numpy_op(stacked, axis=0)  # pylint: disable=not-callable
             return relax.const(output, output.dtype)
 
-        # Expand inputs, stack them, then perform minimum over the new axis.
+        # Expand inputs, stack them, then reduce with relax_op over the new axis.
         inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
         stacked_tensor = relax.op.concat(inputs, axis=0)
-        return relax.op.min(stacked_tensor, axis=0)
+        return cls.relax_op(stacked_tensor, axis=0)  # pylint: disable=not-callable
+
+
+class Min(MultiInputBase):
+    """Converts an onnx Min node into an equivalent Relax expression."""
+
+    numpy_op = _np.min
+    relax_op = relax.op.min
 
 
-class Max(OnnxOpConverter):
+class Max(MultiInputBase):
     """Converts an onnx Max node into an equivalent Relax expression."""
 
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            np_inputs = [inp.data.numpy() for inp in inputs]
-            output = _np.maximum(*np_inputs)
-            return relax.const(output, output.dtype)
+    numpy_op = _np.max
+    relax_op = relax.op.max
 
-        # Expand inputs, stack them, then perform maximum over the new axis.
-        inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
-        stacked_tensor = relax.op.concat(inputs, axis=0)
-        return relax.op.max(stacked_tensor, axis=0)
+
+class Mean(MultiInputBase):
+    """Converts an onnx Mean node into an equivalent Relax expression."""
+
+    numpy_op = _np.mean
+    relax_op = relax.op.mean
+
+
+class Sum(MultiInputBase):
+    """Converts an onnx Sum node into an equivalent Relax expression."""
+
+    numpy_op = _np.sum
+    relax_op = relax.op.sum
 
 
 class Log(OnnxOpConverter):
@@ -956,26 +1348,22 @@ class Exp(OnnxOpConverter):
         return relax.op.exp(data)
 
 
-class Less(OnnxOpConverter):
-    """Converts an onnx Less node into an equivalent Relax expression."""
+class Softplus(OnnxOpConverter):
+    """Converts an onnx Softplus node, ``log(1 + exp(x))``, into Relax."""

     @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = _np.less(inputs[0].data.numpy(), inputs[1].data.numpy())
-            return relax.const(output, output.dtype)
-        return relax.op.less(inputs[0], inputs[1])
+    def _impl_v1(cls, bb, inputs, attr, params):
+        # log(1 + exp(x)) == max(x, 0) + log(1 + exp(-|x|)). The right-hand form
+        # never exponentiates a large positive value, so it does not overflow
+        # to inf for large x (e.g. x = 100 in float32).
+        data = inputs[0]
+        dtype = data.struct_info.dtype
+        one = relax.const(1, dtype=dtype)
+        neg_abs = relax.op.negative(relax.op.abs(data))
+        return relax.op.nn.relu(data) + relax.op.log(relax.op.exp(neg_abs) + one)
 
 
-class LessOrEqual(OnnxOpConverter):
-    """Converts an onnx LessOrEqual node into an equivalent Relax 
expression."""
+class Softsign(OnnxOpConverter):
+    """Convert an ONNX Softsign node, computing ``x / (1 + |x|)``."""

     @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = _np.less_equal(inputs[0].data.numpy(), 
inputs[1].data.numpy())
-            return relax.const(output, output.dtype)
-        return relax.op.less_equal(inputs[0], inputs[1])
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        denominator = relax.op.abs(x) + relax.const(1, dtype=x.struct_info.dtype)
+        return x / denominator
 
 
 class Split(OnnxOpConverter):
@@ -1456,6 +1844,20 @@ class BatchNormalization(OnnxOpConverter):
         )
 
 
+class MeanVarianceNormalization(OnnxOpConverter):
+    """Converts an onnx MeanVarianceNormalization node into an equivalent Relax expression.
+
+    Mirrors the ONNX function body
+    ``(X - E[X]) / (sqrt(E[X^2] - E[X]^2) + 1e-9)``, with the expectations
+    taken over ``axes`` (default ``(0, 2, 3)``).
+    """
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        # Build constants in the input dtype rather than hard-coded float32 so
+        # the binary ops below see matching dtypes for non-float32 inputs.
+        dtype = data.struct_info.dtype
+        axis = attr.get("axes", (0, 2, 3))
+        two = relax.const(2, dtype=dtype)
+        data_mean = relax.op.mean(data, axis=axis, keepdims=True)
+        data_mean_squared = relax.op.power(data_mean, two)
+        data_squared_mean = relax.op.mean(relax.op.power(data, two), axis=axis, keepdims=True)
+        std = relax.op.sqrt(data_squared_mean - data_mean_squared)
+        # Epsilon keeps constant slices (std == 0) from producing NaN/inf.
+        epsilon = relax.const(1e-9, dtype=dtype)
+        return (data - data_mean) / (std + epsilon)
+
+
 class Pool(OnnxOpConverter):
     """A helper class for pool op converters."""
 
@@ -1557,16 +1959,79 @@ class GlobalAveragePool(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, bb, inputs, attr, params):
         rank = len(inputs[0].struct_info.shape)
-        if rank == 3:
-            return relax.op.nn.adaptive_avg_pool1d(inputs[0], 1)
-        elif rank == 4:
-            return relax.op.nn.adaptive_avg_pool2d(inputs[0], 1)
-        elif rank == 5:
-            return relax.op.nn.adaptive_avg_pool3d(inputs[0], 1)
-        raise NotImplementedError(
-            "Global average pooling is only implemented for 1D, 2D, and 3D 
kernels, got %dD."
-            % (rank - 2)
+        # Average over every axis after the first two; keepdims leaves each as size 1.
+        return relax.op.mean(inputs[0], axis=list(range(2, rank)), keepdims=True)
+
+
+class GlobalMaxPool(OnnxOpConverter):
+    """Convert an ONNX GlobalMaxPool node into a Relax max over all spatial axes."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        spatial_axes = list(range(2, len(data.struct_info.shape)))
+        # keepdims=True keeps every reduced axis as a size-1 dimension.
+        return relax.op.max(data, axis=spatial_axes, keepdims=True)
+
+
+class GlobalLpPool(OnnxOpConverter):
+    """Convert an ONNX GlobalLpPool node: the p-norm taken over all spatial axes."""
+
+    @classmethod
+    def _impl_v2(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        p = attr.get("p", 2.0)
+        dtype = data.struct_info.dtype
+        spatial_axes = list(range(2, len(data.struct_info.shape)))
+        # (sum |x|^p) ^ (1/p), keeping reduced axes as size-1 dims.
+        powered = relax.op.power(relax.op.abs(data), relax.const(p, dtype=dtype))
+        summed = relax.op.sum(powered, spatial_axes, keepdims=True)
+        return relax.op.power(summed, relax.const(1.0 / p, dtype=dtype))
+
+
+class MaxUnpool(OnnxOpConverter):
+    """Converts an onnx MaxUnpool node into an equivalent Relax expression.
+
+    Values of ``data`` are scattered into a zero tensor at the flattened
+    positions given by ``indices``. The output shape comes from the optional
+    ``output_shape`` input, or is derived from kernel_shape/strides/pads.
+    """
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        indices = inputs[1]
+        output_shape = inputs[2]
+        kernel_shape = attr.get("kernel_shape")
+        pads = attr.get("pads", [0] * len(kernel_shape) * 2)
+        strides = attr.get("strides", [1] * len(kernel_shape))
+
+        # Default size per spatial axis: (in - 1) * stride + kernel.
+        multiplier = _np.concatenate([[1, 1], list(strides)])
+        shape = [v.value for v in data.struct_info.shape]
+        total_output_shape = multiplier * shape
+        total_output_shape += _np.concatenate([[0, 0], list(kernel_shape)], axis=0)
+        total_output_shape -= _np.concatenate([[0, 0], list(strides)], axis=0)
+
+        if output_shape is not None:
+            # An explicit output_shape input replaces the computed full shape.
+            # It must be materialized as numbers: .tolist() is used below.
+            if not isinstance(output_shape, relax.Constant):
+                raise NotImplementedError("MaxUnpool output_shape must be a constant.")
+            total_output_shape = output_shape.data.numpy().astype("int64")
+        elif pads is not None:
+            # ONNX pads are [x1_begin, x2_begin, ..., x1_end, x2_end, ...], so the
+            # total padding of spatial axis i is pads[i] + pads[i + num_spatial_axes].
+            begin_end = _np.asarray(pads, dtype="int64").reshape(2, -1)
+            total_pad = _np.concatenate([[0, 0], begin_end.sum(axis=0)])
+            # Reversing maxpool means that padding actually makes our output smaller.
+            total_output_shape = total_output_shape - total_pad
+
+        # Create a tensor of zeros then scatter our data through it.
+        relax_shape = relax.ShapeExpr(total_output_shape.tolist())
+        zeros_tensor = bb.emit(relax.op.zeros(relax_shape, data.struct_info.dtype))
+        # Flatten everything so the scatter runs along a single axis.
+        flat_tensor = relax.op.scatter_elements(
+            relax.op.reshape(zeros_tensor, [-1]),
+            relax.op.reshape(indices, [-1]),
+            relax.op.reshape(data, [-1]),
+            axis=0,
         )
+        # Reshape our flattened data back to normal.
+        return relax.op.reshape(flat_tensor, relax_shape)
 
 
 class Flatten(OnnxOpConverter):
@@ -1799,6 +2264,32 @@ class ArgMin(OnnxOpConverter):
         return relax.op.argmin(data, axis, keepdims)
 
 
+class TopK(OnnxOpConverter):
+    """Converts an onnx TopK node into an equivalent Relax expression.
+
+    Both outputs (values and indices) are produced via ``ret_type="both"``.
+    """
+
+    @staticmethod
+    def _constant_k(k):
+        """Return ``k`` from the K input as a Python int; only constants are supported."""
+        if not isinstance(k, relax.Constant):
+            raise ValueError("TopK k must be a constant")
+        # K is a single-element tensor; .item() avoids int() on a 1-D array.
+        return int(k.data.numpy().item())
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        k = cls._constant_k(inputs[1])
+        axis = attr.get("axis", -1)
+        largest = attr.get("largest", 1)
+        is_sorted = attr.get("sorted", 1)
+        if is_sorted != 1:
+            raise ValueError("TopK sorted must be 1 for Relax frontend")
+
+        return relax.op.topk(data, k, axis, ret_type="both", largest=bool(largest))
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        # Opset 10 takes k from the second input instead of the ``k`` attribute
+        # read by _impl_v1. ``largest``/``sorted`` only exist from opset 11 on, so
+        # this mirrors _impl_v1 and relies on relax.op.topk's default selection.
+        data = inputs[0]
+        k = cls._constant_k(inputs[1])
+        axis = attr.get("axis", -1)
+        return relax.op.topk(data, k, axis, ret_type="both")
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        k = attr.get("k", 1)
+        axis = attr.get("axis", -1)
+        return relax.op.topk(data, k, axis, ret_type="both")
+
 class SkipLayerNormalization(OnnxOpConverter):
     """Converts a microsoft contrib SkipLayerNormalization node into a Relax 
expression."""
 
@@ -1871,26 +2362,6 @@ class EmbedLayerNormalization(OnnxOpConverter):
         return relax.Tuple([ln, mask_index])
 
 
-class Greater(OnnxOpConverter):
-    """Converts an onnx Greater node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = _np.greater(inputs[0].data.numpy(), 
inputs[1].data.numpy())
-            return relax.const(output, output.dtype)
-        return relax.op.greater(inputs[0], inputs[1])
-
-
-class Reciprocal(OnnxOpConverter):
-    """Converts an onnx Reciprocal node into an equivalent Relax expression."""
-
-    @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        input_dtype = inputs[0].struct_info.dtype
-        return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0])
-
-
 class OneHot(OnnxOpConverter):
     """Converts an onnx OneHot node into an equivalent Relax expression."""
 
@@ -1909,15 +2380,16 @@ class OneHot(OnnxOpConverter):
         return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, 
axis, dtype)
 
 
-class Elu(OnnxOpConverter):
-    """Converts an onnx Elu node into an equivalent Relax expression."""
+class Unique(OnnxOpConverter):
+    """Convert an ONNX Unique node into ``relax.op.unique`` (unique values only)."""

     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
-        alpha = float(attr.get("alpha", 1.0))
-        return relax.expr.const(-alpha) * relax.op.nn.relu(
-            relax.expr.const(1.0) - relax.op.exp(inputs[0])
-        ) + relax.op.nn.relu(inputs[0])
+    def _impl_v11(cls, bb, inputs, attr, params):
+        # return_index/return_inverse/return_counts are left at relax.op.unique's
+        # defaults, so only the first ONNX output is produced.
+        # TODO(tvm-team): Add support for return_index, return_inverse, return_counts
+        axis = attr.get("axis", None)
+        keep_sorted = bool(attr.get("sorted", 1))
+        return relax.op.unique(inputs[0], sorted=keep_sorted, axis=axis)
 
 
 class HardSigmoid(OnnxOpConverter):
@@ -1966,53 +2438,308 @@ class Not(OnnxOpConverter):
         return relax.op.logical_not(inputs[0])
 
 
+class DepthToSpace(OnnxOpConverter):
+    """Convert an ONNX DepthToSpace node (DCR or CRD mode) via reshape/transpose/reshape."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        block_size = int(attr["blocksize"])
+        mode = attr.get("mode", b"DCR").decode("utf-8")
+        b, c, h, w = inputs[0].struct_info.shape
+        if mode not in ("DCR", "CRD"):
+            raise ValueError(f"Unsupported mode: {mode}, expected DCR or CRD")
+
+        depth = c // (block_size**2)
+        if mode == "DCR":
+            # Channels are split as (block_row, block_col, depth).
+            blocks = relax.op.reshape(inputs[0], (b, block_size, block_size, depth, h, w))
+            blocks = relax.op.permute_dims(blocks, [0, 3, 4, 1, 5, 2])
+        else:
+            # Channels are split as (depth, block_row, block_col).
+            blocks = relax.op.reshape(inputs[0], (b, depth, block_size, block_size, h, w))
+            blocks = relax.op.permute_dims(blocks, [0, 1, 4, 2, 5, 3])
+        return relax.op.reshape(blocks, (b, depth, h * block_size, w * block_size))
+
+
+class SpaceToDepth(OnnxOpConverter):
+    """Convert an ONNX SpaceToDepth node by folding spatial blocks into channels."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        bs = int(attr["blocksize"])
+        b, c, h, w = inputs[0].struct_info.shape
+        out_h = h // bs
+        out_w = w // bs
+        # (b, c, h, w) -> (b, c, h/bs, bs, w/bs, bs) -> (b, bs, bs, c, h/bs, w/bs)
+        tiles = relax.op.reshape(inputs[0], (b, c, out_h, bs, out_w, bs))
+        tiles = relax.op.permute_dims(tiles, [0, 3, 5, 1, 2, 4])
+        return relax.op.reshape(tiles, (b, c * bs * bs, out_h, out_w))
+
+
+class SequenceConstruct(OnnxOpConverter):
+    """Convert an ONNX SequenceConstruct node; sequences are modelled as Relax tuples."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        # Each input tensor becomes one field of the tuple, in order.
+        fields = list(inputs)
+        return relax.Tuple(fields)
+
+
+class SequenceEmpty(OnnxOpConverter):
+    """Convert an ONNX SequenceEmpty node into a Relax tuple with no fields."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        # ``attr`` is not consulted; an empty sequence is just an empty tuple.
+        no_fields = []
+        return relax.Tuple(no_fields)
+
+
+class SequenceErase(OnnxOpConverter):
+    """Convert an ONNX SequenceErase node by rebuilding the tuple without one field."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        input_sequence = inputs[0]
+
+        # Without an explicit position the last tensor is erased.
+        position = -1
+        if len(inputs) == 2:
+            position_expr = inputs[1]
+            # Non constant position is not supported.
+            if not isinstance(position_expr, relax.Constant):
+                raise NotImplementedError("Position must be a constant.")
+            position = int(position_expr.data.numpy())
+
+        seq_len = len(input_sequence)
+        if position < -seq_len or position >= seq_len:
+            raise ValueError(
+                f"Position is out of bounds, expected [-{seq_len}, {seq_len}), got {position}"
+            )
+
+        # Map a negative position onto its non-negative equivalent.
+        erase_index = position + seq_len if position < 0 else position
+        kept = [input_sequence[idx] for idx in range(seq_len) if idx != erase_index]
+        return relax.Tuple(kept)
+
+
+class SequenceInsert(OnnxOpConverter):
+    """Operator converter for sequence insert op.
+
+    ``position`` follows ``list.insert`` semantics as in the ONNX reference:
+    ``-1`` inserts before the last tensor, ``n`` appends, and no position
+    appends.
+    """
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        # Insert a new tensor into a tuple of tensors.
+        input_sequence = inputs[0]
+        tensor_to_insert = inputs[1]
+        seq_len = len(input_sequence)
+
+        if len(inputs) == 3:
+            position = inputs[2]
+            # Non constant position is not supported.
+            if not isinstance(position, relax.Constant):
+                raise NotImplementedError("Position must be a constant.")
+            position = int(position.data.numpy().item())
+            # Accepted range is [-n, n]; list.insert would silently clamp the rest.
+            if not -seq_len <= position <= seq_len:
+                raise ValueError(
+                    f"Position is out of bounds, expected [-{seq_len}, {seq_len}], got {position}"
+                )
+            if position < 0:
+                position += seq_len
+        else:
+            # No position given: append at the back.
+            position = seq_len
+
+        # Convert sequence to a list, insert new tensor, and repackage as Tuple.
+        tensor_list = [input_sequence[i] for i in range(seq_len)]
+        tensor_list.insert(position, tensor_to_insert)
+        return relax.Tuple(tensor_list)
+
+
+class SequenceLength(OnnxOpConverter):
+    """Convert an ONNX SequenceLength node into a constant int64 scalar."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        # The number of tuple fields is known at import time, so fold it.
+        num_tensors = len(inputs[0])
+        return relax.const(num_tensors, dtype="int64")
+
+
+class ConcatFromSequence(OnnxOpConverter):
+    """Operator converter for sequence concatenation op.
+
+    With ``new_axis=1`` each tensor first gets a size-1 dimension at ``axis``,
+    so the sequence is stacked rather than concatenated.
+    """
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        new_axis = attr.get("new_axis", 0)
+        input_sequence = inputs[0]
+
+        if new_axis == 1:
+            # Read the arity from struct info so this also works when the
+            # sequence is a call result rather than a literal relax.Tuple.
+            num_tensors = len(input_sequence.struct_info.fields)
+            input_sequence = [
+                relax.op.expand_dims(bb.normalize(relax.TupleGetItem(input_sequence, i)), axis)
+                for i in range(num_tensors)
+            ]
+
+        return relax.op.concat(input_sequence, axis=axis)
+
+
+class SplitToSequence(OnnxOpConverter):
+    """Operator converter for split to sequence op."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        keepdims = attr.get("keepdims", 1)
+
+        input_tensor = inputs[0]
+        input_shape = input_tensor.struct_info.shape
+
+        # The optional split input may be absent or given as an empty name (None).
+        split = inputs[1] if len(inputs) > 1 else None
+        if split is None:
+            # No split: chunks of size 1 along axis.
+            if not keepdims:
+                raise NotImplementedError("Only keepdims=1 is supported for now")
+            split = _np.array(1)
+        else:
+            if not isinstance(split, relax.Constant):
+                raise ValueError("Only constant split supported for SplitToSequence")
+            split = split.data.numpy()
+
+        if split.ndim == 1 and split.shape[0] > 1:
+            # Explicit chunk lengths: cut at their running sums, excluding the total.
+            split = list(_np.cumsum(split)[:-1])
+        else:
+            # A scalar (or single-element) split is a chunk size. reshape(-1)[0]
+            # avoids NumPy's deprecated int() conversion of a 1-D array.
+            chunk_size = int(split.reshape(-1)[0])
+            dim_size = input_shape[axis]
+            if isinstance(dim_size, tvm.tir.IntImm) and int(dim_size) % chunk_size != 0:
+                # ONNX allows the last chunk to be shorter than ``split``: cut at
+                # every multiple of chunk_size below the static axis length.
+                split = list(range(chunk_size, int(dim_size), chunk_size))
+            elif dim_size % chunk_size != 0:
+                raise ValueError(
+                    f"Dimension of size {dim_size} along axis {axis} must be "
+                    f"evenly divisible by chunk size {chunk_size}"
+                )
+            else:
+                split = dim_size // chunk_size
+
+        return relax.op.split(input_tensor, split, axis=axis)
+
+
+class SequenceAt(OnnxOpConverter):
+    """Operator converter for sequence at op."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        input_sequence = inputs[0]
+        position = inputs[1]
+        # Raise rather than assert: assertions are stripped under ``python -O``.
+        if not isinstance(position, relax.Constant):
+            raise NotImplementedError("Only constant position supported for SequenceAt")
+        position = int(position.data.numpy().item())
+        if position < 0:
+            # Normalize to a non-negative tuple index so sequences that are not a
+            # literal relax.Tuple (e.g. a split call) index correctly.
+            # NOTE(review): assumes the sequence carries TupleStructInfo.
+            position += len(input_sequence.struct_info.fields)
+        return input_sequence[position]
+
+
 def _get_convert_map():
     return {
-        "MatMul": MatMul,
-        "Concat": Concat,
+        # defs/experimental
+        # "Optional": Optional_,
+        # "OptionalHasElement": OptionalHasElement,
+        # "OptionalGetElement": OptionalGetElement,
+        # Binary operators
         "Add": Add,
+        "Sub": Sub,
         "Mul": Mul,
-        "Cast": Cast,
+        "Div": Div,
+        # "Mod": Mod,
+        "Less": Less,
+        "LessOrEqual": LessOrEqual,
+        "Greater": Greater,
+        "GreaterOrEqual": GreaterOrEqual,
+        "Equal": Equal,
+        "BitwiseAnd": BitwiseAnd,
+        "BitwiseOr": BitwiseOr,
+        "BitwiseXor": BitwiseXor,
+        # "BitwiseNot": BitwiseNot,
+        # "BitwiseShift": BitwiseShift,
+        "And": And,
+        "Or": Or,
+        "Xor": Xor,
+        "Not": Not,
+        # Unary operators
+        "Log": Log,
+        "Exp": Exp,
+        "Acos": Acos,
+        "Acosh": Acosh,
+        "Asin": Asin,
+        "Asinh": Asinh,
+        "Atan": Atan,
+        "Atanh": Atanh,
+        "Cos": Cos,
+        "Cosh": Cosh,
+        "Sin": Sin,
+        "Sinh": Sinh,
+        "Tan": Tan,
+        "Tanh": Tanh,
+        "Neg": Neg,
+        "Abs": Abs,
+        "Reciprocal": Reciprocal,
+        "Floor": Floor,
+        "Ceil": Ceil,
+        "Round": Round,
+        "IsInf": IsInf,
+        "IsNaN": IsNaN,
+        "Sqrt": Sqrt,
+        "Relu": Relu,
+        "Selu": Selu,
+        "Mish": Mish,
+        "Trilu": Trilu,
+        "PRelu": PRelu,
+        "LeakyRelu": LeakyRelu,
+        "ThresholdedRelu": ThresholdedRelu,
+        "Elu": Elu,
+        "Gelu": Gelu,
+        "FastGelu": FastGelu,
+        "BiasGelu": BiasGelu,
+        "HardSigmoid": HardSigmoid,
+        "HardSwish": HardSwish,
+        "Sign": Sign,
+        "Softplus": Softplus,
+        "Softsign": Softsign,
+        "Shrink": Shrink,
+        "Erf": Erf,
         "Sum": Sum,
-        "Gather": Gather,
+        "Min": Min,
+        "Max": Max,
+        "Mean": Mean,
+        "Cast": Cast,
         "Gemm": Gemm,
+        "MatMul": MatMul,
+        # "MatMulInteger": MatMulInteger,
+        # "MatMulInteger16": MatMulInteger16,
         "Reshape": Reshape,
-        "Div": Div,
         "Sigmoid": Sigmoid,
         "Softmax": Softmax,
+        "LogSoftmax": LogSoftmax,
+        # "Hardmax": Hardmax,
         "Transpose": Transpose,
         "Unsqueeze": Unsqueeze,
-        "Gelu": Gelu,
-        "BiasGelu": BiasGelu,
         "Where": Where,
+        "Concat": Concat,
         "Clip": Clip,
-        "Equal": Equal,
         "Shape": Shape,
-        "Tanh": Tanh,
-        "Sqrt": Sqrt,
-        "Trilu": Trilu,
-        "Relu": Relu,
-        "Conv": Conv,
         "Pow": Pow,
-        "Erf": Erf,
         "CumSum": CumSum,
         "Squeeze": Squeeze,
         "Constant": Constant,
-        "Sub": Sub,
-        "Sin": Sin,
-        "Cos": Cos,
-        "Neg": Neg,
-        "Abs": Abs,
-        "Min": Min,
-        "Max": Max,
-        "Log": Log,
-        "Exp": Exp,
-        "Less": Less,
-        "LessOrEqual": LessOrEqual,
+        "Gather": Gather,
+        # "GatherElements": GatherElements,
+        # "GatherND": GatherND,
+        "Scatter": Scatter,
+        "ScatterElements": ScatterElements,
+        # "ScatterND": ScatterND,
+        # "Compress": Compress,
+        "Size": Size,
+        # "EyeLike": EyeLike,
+        # Normalization
+        "BatchNormalization": BatchNormalization,
         "LayerNormalization": LayerNormalization,
         "SkipLayerNormalization": SkipLayerNormalization,
         "EmbedLayerNormalization": EmbedLayerNormalization,
         "InstanceNormalization": InstanceNormalization,
+        "MeanVarianceNormalization": MeanVarianceNormalization,
         # defs/reduction
         "ReduceMax": ReduceMax,
         "ReduceMin": ReduceMin,
@@ -2026,6 +2753,7 @@ def _get_convert_map():
         "ReduceL2": ReduceL2,
         "ArgMax": ArgMax,
         "ArgMin": ArgMin,
+        "TopK": TopK,
         "Expand": Expand,
         "ConstantOfShape": ConstantOfShape,
         "Slice": Slice,
@@ -2033,23 +2761,42 @@ def _get_convert_map():
         "Pad": Pad,
         "Split": Split,
         "Tile": Tile,
-        "BatchNormalization": BatchNormalization,
-        "MaxPool": MaxPool,
         "AveragePool": AveragePool,
+        "MaxPool": MaxPool,
+        # "LpPool": LpPool,
         "GlobalAveragePool": GlobalAveragePool,
+        "GlobalMaxPool": GlobalMaxPool,
+        "GlobalLpPool": GlobalLpPool,
+        "MaxUnpool": MaxUnpool,
+        "Conv": Conv,
+        "ConvTranspose": ConvTranspose,
         "Flatten": Flatten,
         "Identity": Identity,
         "Resize": Resize,
         "Einsum": Einsum,
         "Range": Range,
-        "Greater": Greater,
-        "Reciprocal": Reciprocal,
         "OneHot": OneHot,
-        "Elu": Elu,
-        "HardSigmoid": HardSigmoid,
-        "HardSwish": HardSwish,
-        "Sign": Sign,
-        "Not": Not,
+        "Unique": Unique,
+        # "NonZero": NonZero,
+        # "If": If,
+        # "LRN": LRN,
+        # "MaxRoiPool": MaxRoiPool,
+        # "RoiAlign": RoiAlign,
+        # "NonMaxSuppression": NonMaxSuppression,
+        # "GridSample": GridSample,
+        # "Upsample": Upsample,
+        # others
+        "DepthToSpace": DepthToSpace,
+        "SpaceToDepth": SpaceToDepth,
+        # Sequence operators
+        "SequenceConstruct": SequenceConstruct,
+        "SequenceEmpty": SequenceEmpty,
+        "SequenceErase": SequenceErase,
+        "SequenceInsert": SequenceInsert,
+        "SequenceLength": SequenceLength,
+        "ConcatFromSequence": ConcatFromSequence,
+        "SplitToSequence": SplitToSequence,
+        "SequenceAt": SequenceAt,
     }
 
 
@@ -2269,6 +3016,14 @@ class ONNXGraphImporter:
                 "Where",
                 "Cast",
             ]
+            return_tuple_ops = [
+                "SequenceConstruct",
+                "SequenceEmpty",
+                "SequenceErase",
+                "SequenceInsert",
+                "ConcatFromSequence",
+                "SplitToSequence",
+            ]
             for i, inp in enumerate(inputs):
                 if (
                     inp is not None
@@ -2277,11 +3032,17 @@ class ONNXGraphImporter:
                     and op_name not in shape_compatible_ops
                 ):
                     raise ValueError(f"Node {node.name} cannot handle 
ShapeExpr inputs.")
-            op = self._convert_operator(op_name, inputs, attr, self.opset)
-            # Create struct information for the new operator.
-            op = self.bb.normalize(op)
-
-            if not isinstance(op, relax.Tuple):
+            try:
+                op = self._convert_operator(op_name, inputs, attr, self.opset)
+                # Create struct information for the new operator.
+                op = self.bb.normalize(op)
+            except TVMError as err:
+                print(f"Error converting operator {op_name}, with inputs: 
{inputs}")
+                raise err
+
+            if op_name in return_tuple_ops:
+                outputs_num = 1
+            elif not isinstance(op, relax.Tuple):
                 if isinstance(op.checked_type, tvm.ir.type.TupleType):
                     # This is a var bound to a tuple. We need to unpack it and 
create
                     # a new tuple.
@@ -2299,7 +3060,6 @@ class ONNXGraphImporter:
             ), "Missing outputs during conversion. Expected {} but Got {} in 
{}.".format(
                 len(outputs), outputs_num, op_name
             )
-
             if outputs_num == 1:
                 self._nodes[outputs[0]] = op
             else:
@@ -2346,10 +3106,10 @@ class ONNXGraphImporter:
     def _convert_operator(
         self,
         op_name: str,
-        inputs: List[relax.Function],
+        inputs: List[relax.Expr],
         attrs: Dict,
         opset: int,
-    ) -> relax.Function:
+    ) -> relax.Expr:
         """Convert ONNX operator into a Relax operator.
         The converter must specify conversions explicitly for incompatible 
name, and
         apply handlers to operator attributes.
@@ -2386,7 +3146,7 @@ def from_onnx(
     opset: int = None,
     keep_params_in_input: bool = False,
     sanitize_input_names: bool = True,
-) -> Tuple[IRModule, Dict]:
+) -> IRModule:
     """Convert a ONNX model into an equivalent Relax Function.
     ONNX graphs are represented as Python Protobuf objects.
 
@@ -2413,8 +3173,6 @@ def from_onnx(
     -------
     mod : tvm.IRModule
         The relax module for compilation
-    params : dict of str to tvm.nd.NDArray
-        The parameter dict to be used by relax
     """
     # Error if the model version is below 1.1.0
     if model.ir_version < 3:
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index 4d106ad6d2..0b86e19ce5 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -77,7 +77,7 @@ def unique(
         return_inverse = PrimValue(return_inverse)
     if isinstance(return_counts, bool):
         return_counts = PrimValue(return_counts)
-    if axis and isinstance(axis, int):
+    if axis is not None and isinstance(axis, int):
         axis = PrimValue(axis)
     return _ffi_api.unique(  # type: ignore
         x, sorted, return_index, return_inverse, return_counts, axis
@@ -91,6 +91,7 @@ def numpy_unique(
     return_index: int,
     return_inverse: int,
     return_counts: int,
+    axis: Optional[int] = None,
 ) -> tvm.nd.array:
     """Returns the unique elements of the input tensor.
 
@@ -103,8 +104,9 @@ def numpy_unique(
         raise NotImplementedError("missing support return_inverse or 
return_counts set to true")
     x_numpy = x.numpy()
     # TODO(prakalp): use torch.unique instead of numpy when torch is installed 
in ci.
-    output_sorted_numpy, indices = np.unique(x_numpy, return_index=True)
+    output_sorted_numpy, indices = np.unique(x_numpy, return_index=True, 
axis=axis)
+
     if sorted:
         return tvm.nd.array(output_sorted_numpy)
-    output_numpy = [x_numpy.flatten()[index] for index in 
builtins.sorted(indices, reverse=True)]
+    output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis)
     return tvm.nd.array(output_numpy)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index 809d231fd3..8317d4504e 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -171,21 +171,16 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
             "and thus cannot be legalized by TOPI"
         )
         return call
-    if call.attrs.groups != 1:
-        logging.info(
-            "TOPI conv1d_transpose does not support groups other than 1, "
-            "and thus cannot be legalized by TOPI"
-        )
-        return call
 
     return bb.call_te(
-        topi.nn.conv1d_transpose_ncw,
+        topi.nn.group_conv1d_transpose_ncw,
         call.args[0],
         call.args[1],
         stride=call.attrs.strides,
         padding=call.attrs.padding,
         out_dtype=call.struct_info.dtype,
         output_padding=call.attrs.output_padding,
+        groups=call.attrs.groups,
         primfunc_name_hint="conv1d_transpose",
     )
 
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 0e7cfbd7c0..2837ad2185 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -21,7 +21,7 @@ ONNX testcases
 This file is a test script to test Relax ONNX frontend coverage.
 """
 
-from typing import Dict, Optional
+from typing import Dict, List, Literal, Optional
 
 import numpy as np
 import onnx
@@ -118,6 +118,7 @@ def check_correctness(
     tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
     # Legalize any relax ops into tensorir.
     tvm_model = relax.transform.LegalizeOps()(tvm_model)
+    print(tvm_model)
 
     # Separate model from parameters.
     tvm_model, params = relax.frontend.detach_params(tvm_model)
@@ -137,25 +138,31 @@ def check_correctness(
     vm.invoke_stateful("main")
     tvm_output = vm.get_outputs("main")
     # Wrap as a list if there is only one output.
-    if isinstance(tvm_output, tvm.nd.NDArray):
+    if len(ort_output) == 1:
+        # Do not check the output number for TVM
+        # As for sequence output, the TVM output is a Tuple
+        # while the ONNX output number is one, which is a list
         tvm_output = [tvm_output]
-    # If the output is a shape tuple, convert it to an ndarray for comparison.
-    if isinstance(tvm_output, tvm.runtime.ShapeTuple):
-        tvm_output = [tvm.nd.array([int(i) for i in tvm_output])]
 
-    tvm_num_outputs = len(tvm_output)
-    # Shape tuples need to be handled specially.
-    if isinstance(tvm_output, tvm.runtime.ShapeTuple):
-        tvm_num_outputs = 1
+    def _check_output(tvm_out, ort_out):
+        if isinstance(tvm_out, tuple) and isinstance(ort_out, 
(tvm.runtime.ShapeTuple, list)):
+            assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
+            for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
+                _check_output(tvm_out_i, ort_out_i)
+        elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, 
np.ndarray):
+            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, 
atol=atol)
+        elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and 
isinstance(ort_out, np.ndarray):
+            shape_out = tvm.nd.array([int(i) for i in tvm_out])
+            tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, 
atol=atol)
+        else:
+            raise ValueError(f"Unsupported types: {type(tvm_out)}, 
{type(ort_out)}")
 
     # Check that number of outputs match.
-    assert tvm_num_outputs == len(ort_output), "Unequal number of outputs"
-
+    assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
     for (tvm_out, ort_out) in zip(tvm_output, ort_output):
         # TODO Allow configurable tolerance.
-        # Sometimes None is used to indicate an unused output.
         if ort_out is not None:
-            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, 
atol=atol)
+            _check_output(tvm_out, ort_out)
 
 
 @pytest.mark.parametrize(
@@ -187,35 +194,61 @@ def test_sanitize(input_names, expected_names):
         assert param.name_hint == expected_names[i]
 
 
-def verify_unary(op_name, shape, attrs={}, domain=None, 
dtype=TensorProto.FLOAT):
+def verify_unary(
+    op_name,
+    shape,
+    attrs={},
+    domain=None,
+    input_dtype=TensorProto.FLOAT,
+    output_dtype=TensorProto.FLOAT,
+    opset=14,
+):
     test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
     graph = helper.make_graph(
         [test_node],
         "elemwise_test",
         inputs=[
-            helper.make_tensor_value_info("x", dtype, shape),
+            helper.make_tensor_value_info("x", input_dtype, shape),
         ],
-        outputs=[helper.make_tensor_value_info("y", dtype, shape)],
+        outputs=[helper.make_tensor_value_info("y", output_dtype, shape)],
     )
 
     model = helper.make_model(graph, producer_name="elemwise_test")
-    check_correctness(model)
+    check_correctness(model, opset=opset)
 
 
-def verify_binary(op_name, shape_a, shape_b, shape_c, attrs={}, domain=None):
+def verify_binary(
+    op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, 
dtype=TensorProto.FLOAT, opset=14
+):
     test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, 
domain=domain)
     graph = helper.make_graph(
         [test_node],
         "binary_test",
         inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b),
+            helper.make_tensor_value_info("a", dtype, shape_a),
+            helper.make_tensor_value_info("b", dtype, shape_b),
         ],
-        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, 
shape_c)],
+        outputs=[helper.make_tensor_value_info("c", dtype, shape_c)],
     )
 
     model = helper.make_model(graph, producer_name="binary_test")
-    check_correctness(model)
+    check_correctness(model, opset=opset)
+
+
+def verify_binary_scalar(op_name, attrs={}, domain=None, 
dtype=TensorProto.INT32, opset=14):
+    a = make_constant_node("a", dtype, [], [4])
+    b = make_constant_node("b", dtype, [], [8])
+    test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, 
domain=domain)
+    graph = helper.make_graph(
+        [a, b, test_node],
+        "binary_test",
+        inputs=[],
+        outputs=[helper.make_tensor_value_info("c", dtype, ())],
+    )
+
+    model = helper.make_model(graph, producer_name="binary_test")
+    # NOTE: explicitly pass inputs to avoid numerical error
+    check_correctness(model, opset=opset)
 
 
 def verify_compare(op_name, shape, attrs={}, domain=None):
@@ -289,16 +322,95 @@ def test_concat():
     verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0})
 
 
-def test_add():
-    verify_binary("Add", [1, 32], [1, 32], [1, 32])
[email protected]("op_name", ["Add", "Sub", "Mul", "Div", "Pow"])
+def test_binary(op_name: str):
+    verify_binary(op_name, [1, 32], [1, 32], [1, 32])
+    verify_binary_scalar(op_name)
+
+
[email protected]("num_inputs", [1, 2, 4])
[email protected]("op_name", ["Min", "Max", "Sum", "Mean"])
+def test_multi_input(op_name: str, num_inputs: int):
+    input_shape = [32, 32]
+    input_var = ["i" + str(i) for i in range(num_inputs)]
+    input_values = [
+        helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for 
var in input_var
+    ]
+    test_node = helper.make_node(op_name, input_var, ["c"])
+    graph = helper.make_graph(
+        [test_node],
+        "multi_input_test",
+        inputs=input_values,
+        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, 
input_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="multi_input_test")
+    check_correctness(model)
 
 
-def test_mul():
-    verify_binary("Mul", [1, 32], [1, 32], [1, 32])
[email protected]("op_name", ["Less", "LessOrEqual", "Greater", 
"GreaterOrEqual"])
+def test_compare(op_name: str):
+    verify_compare(op_name, [1, 32])
 
 
-def test_sum():
-    verify_binary("Sum", [1, 32], [1, 32], [1, 32])
[email protected]("op_name", ["And", "Or", "Xor"])
+def test_binary_bool(op_name: str):
+    verify_binary(op_name, [32, 32], [32, 32], [32, 32], 
dtype=TensorProto.BOOL)
+
+
[email protected](
+    "op_name",
+    [
+        "Sin",
+        "Cos",
+        "Tan",
+        "Sinh",
+        "Cosh",
+        "Tanh",
+        "Asin",
+        "Acos",
+        "Atan",
+        "Asinh",
+        "Acosh",
+        "Atanh",
+        "Neg",
+        "Abs",
+        "Log",
+        "Exp",
+        "Not",
+        "Reciprocal",
+        "Floor",
+        "Ceil",
+        "Round",
+        "IsInf",
+        "IsNaN",
+        "Sqrt",
+        "Relu",
+        "Elu",
+        "HardSwish",
+        "Sign",
+        "Softplus",
+        "Softsign",
+        "Erf",
+        "Sigmoid",
+        "Softmax",
+        "LogSoftmax",
+        "Identity",
+    ],
+)
+def test_unary(op_name: str):
+    input_dtype = TensorProto.FLOAT
+    if op_name in [
+        "IsNaN",
+        "IsInf",
+    ]:
+        pytest.skip(f"Skipping test {op_name} because current LegalizeOps does 
not support it.")
+    elif op_name == "Not":
+        input_dtype = TensorProto.BOOL
+        output_dtype = TensorProto.BOOL
+    else:
+        output_dtype = TensorProto.FLOAT
+    verify_unary(op_name, [32, 32], input_dtype=input_dtype, 
output_dtype=output_dtype)
 
 
 @pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, 
TensorProto.FLOAT16])
@@ -350,6 +462,44 @@ def test_gather():
     _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
 
 
[email protected]("axis", [0, 1, 2])
[email protected](("name", "opset"), [("Scatter", 10), 
("ScatterElements", 11)])
+def test_scatter(axis: int, name: str, opset: int):
+    if axis != 1:
+        pytest.skip("The current topi impl is wrong, which only works for 
axis=1")
+    input_shape = [16, 16, 16]
+    indices_shape = [8, 8, 8]
+    updates_shape = [8, 8, 8]
+    output_shape = [16, 16, 16]
+    node = helper.make_node(name, ["data", "indices", "updates"], ["output"], 
axis=axis)
+    graph = helper.make_graph(
+        [node],
+        "scatter_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, 
input_shape),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, 
indices_shape),
+            helper.make_tensor_value_info("updates", TensorProto.FLOAT, 
updates_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
output_shape)],
+    )
+    model = helper.make_model(graph, producer_name="scatter_test")
+    indices = np.random.randint(0, 16, indices_shape)
+    check_correctness(model, inputs={"indices": indices}, opset=opset)
+
+
+def test_size():
+    test_node = helper.make_node("Size", ["x"], ["y"])
+    graph = helper.make_graph(
+        [test_node],
+        "size_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3, 
3])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.INT64, [3])],
+    )
+
+    model = helper.make_model(graph, producer_name="size_test")
+    check_correctness(model)
+
+
 @pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
 @pytest.mark.parametrize("beta", [None, 0.35, 1.0])
 @pytest.mark.parametrize("useC", [False, True])
@@ -408,18 +558,6 @@ def test_reshape(in_shape, shape, out_shape):
     check_correctness(model, inputs=input_values)
 
 
-def test_div():
-    verify_binary("Div", [32, 32], [32, 32], [32, 32])
-
-
-def test_sigmoid():
-    verify_unary("Sigmoid", [32, 32])
-
-
-def test_softmax():
-    verify_unary("Softmax", [32, 32, 32])
-
-
 def test_transpose():
     verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]})
 
@@ -567,28 +705,33 @@ def test_shape():
     check_correctness(model)
 
 
-def test_tanh():
-    verify_unary("Tanh", [9, 8, 7, 6])
[email protected]("upper", [True, False])
+def test_trilu(upper: bool):
+    verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper})
 
 
-def test_sqrt():
-    verify_unary("Sqrt", [32, 32])
+def test_selu():
+    verify_unary("Selu", [3, 32, 32])
+    verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3})
 
 
-def test_relu():
-    verify_unary("Relu", [32, 32])
[email protected](reason="opset 18 is not supported in CI")
+def test_mish():
+    verify_unary("Mish", [3, 32, 32], opset=18)
 
 
-def test_tril():
-    verify_unary("Trilu", [3, 5, 5], attrs={"upper": False})
+def test_prelu():
+    verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32])
 
 
-def test_triu():
-    verify_unary("Trilu", [3, 5, 5], attrs={"upper": True})
+def test_thresholded_relu():
+    verify_unary("ThresholdedRelu", [3, 32, 32])
+    verify_unary("ThresholdedRelu", [3, 32, 32], attrs={"alpha": -0.01})
 
 
-def test_elu():
-    verify_unary("Elu", [32, 32])
+def test_leakyrelu():
+    verify_unary("LeakyRelu", [32, 32])
+    verify_unary("LeakyRelu", [32, 32], attrs={"alpha": 0.2})
 
 
 def test_hardsigmoid():
@@ -597,30 +740,40 @@ def test_hardsigmoid():
     verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 
0.6})
 
 
-def test_hardswish():
-    verify_unary("HardSwish", [32, 32])
-
-
-def test_sign():
-    verify_unary("Sign", [32, 32])
-
-
-def test_not():
-    verify_unary("Not", [32, 32], dtype=TensorProto.BOOL)
+def test_shrink():
+    verify_unary("Shrink", [32, 32])
+    verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})
 
 
-def test_conv():
-    def _verify_conv(input_shape, weight_shape, output_shape):
[email protected]("stride", [1, 2])
[email protected]("dilation", [1, 2])
[email protected]("bias", [True, False])
[email protected]("pad", [0, 2])
+def test_conv(stride: int, dilation: int, pad: int, bias: bool):
+    def _verify_conv(input_shape, weight_shape):
+        nd = len(weight_shape) - 2
+        output_shape = [input_shape[0], weight_shape[0]] + [
+            (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) 
// stride + 1
+            for i in range(2, len(input_shape))
+        ]
         bias_shape = [output_shape[1]]
-        conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"])
+        conv_node = helper.make_node(
+            "Conv",
+            inputs=["x", "w"] + (["b"] if bias else []),
+            outputs=["y"],
+            strides=[stride] * nd,
+            dilations=[dilation] * nd,
+            pads=[pad] * nd * 2,
+            group=input_shape[1] // weight_shape[1],
+        )
         graph = helper.make_graph(
             [conv_node],
             "conv_test",
             inputs=[
                 helper.make_tensor_value_info("x", TensorProto.FLOAT, 
input_shape),
                 helper.make_tensor_value_info("w", TensorProto.FLOAT, 
weight_shape),
-                helper.make_tensor_value_info("b", TensorProto.FLOAT, 
bias_shape),
-            ],
+            ]
+            + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, 
bias_shape)] if bias else []),
             outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, 
output_shape)],
         )
 
@@ -628,20 +781,61 @@ def test_conv():
         check_correctness(model, atol=1e-4)
 
     # Conv1D
-    _verify_conv([3, 12, 32], [4, 12, 3], [3, 4, 30])
+    _verify_conv([3, 4, 32], [4, 4, 3])
+    _verify_conv([3, 4, 32], [2, 4, 3])  # group=2
     # Conv2D
-    _verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30])
+    _verify_conv([3, 4, 32, 32], [4, 4, 3, 3])
+    _verify_conv([3, 4, 32, 32], [2, 4, 3, 3])  # group=2
     # Conv3D
-    _verify_conv([3, 12, 32, 32, 32], [4, 12, 3, 3, 3], [3, 4, 30, 30, 30])
+    _verify_conv([3, 4, 32, 32, 32], [4, 4, 3, 3, 3])
+    _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3])  # group=2
+
+
[email protected]("stride", [1, 2])
[email protected]("dilation", [1])
[email protected]("bias", [True, False])
[email protected]("pad", [0, 2])
+def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool):
+    def _verify_conv_transpose(input_shape, weight_shape):
+        nd = len(weight_shape) - 2
+        output_shape = [input_shape[0], weight_shape[0]] + [
+            (input_shape[i] - 1) * stride - 2 * pad + dilation * 
(weight_shape[i] - 1) + 1
+            for i in range(2, len(input_shape))
+        ]
+        bias_shape = [output_shape[1]]
+        conv_node = helper.make_node(
+            "ConvTranspose",
+            inputs=["x", "w"] + (["b"] if bias else []),
+            outputs=["y"],
+            strides=[stride] * nd,
+            dilations=[dilation] * nd,
+            pads=[pad] * nd * 2,
+            group=input_shape[1] // weight_shape[1],
+        )
+        graph = helper.make_graph(
+            [conv_node],
+            "conv_transpose_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, 
input_shape),
+                helper.make_tensor_value_info("w", TensorProto.FLOAT, 
weight_shape),
+            ]
+            + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, 
bias_shape)] if bias else []),
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, 
output_shape)],
+        )
 
+        model = helper.make_model(graph, producer_name="conv_transpose_test")
+        check_correctness(model, atol=1e-4)
 
-def test_pow():
-    verify_binary("Pow", [32, 32], [32, 32], [32, 32])
+    # ConvTranspose1D
+    _verify_conv_transpose([3, 4, 32], [4, 4, 3])
+    _verify_conv_transpose([3, 4, 32], [4, 2, 3])  # group=2
+    # ConvTranspose2D
+    _verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3])
+    _verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3])  # group=2
 
 
-def test_erf():
-    verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT)
-    verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT16)
+def test_pow():
+    verify_binary("Pow", [32, 32], [32, 32], [32, 32])
 
 
 @pytest.mark.parametrize("reverse", [False])
@@ -712,46 +906,6 @@ def test_const():
     check_correctness(model)
 
 
-def test_sub():
-    verify_binary("Sub", [32, 16], [32, 16], [32, 16])
-
-
-def test_min():
-    verify_binary("Min", [32, 16], [32, 16], [32, 16])
-
-
-def test_max():
-    verify_binary("Max", [32, 16], [32, 16], [32, 16])
-
-
-def test_sin():
-    verify_unary("Sin", [32, 16])
-
-
-def test_cos():
-    verify_unary("Cos", [32, 16])
-
-
-def test_identity():
-    verify_unary("Identity", [32, 16])
-
-
-def test_neg():
-    verify_unary("Neg", [32, 16])
-
-
-def test_abs():
-    verify_unary("Abs", [32, 16])
-
-
-def test_log():
-    verify_unary("Log", [32, 16])
-
-
-def test_exp():
-    verify_unary("Exp", [32, 16])
-
-
 def test_instance_norm():
     verify_ternary(
         "InstanceNormalization", [1, 3, 32, 32], [3], [3], [1, 3, 32, 32], 
attrs={"epsilon": 1e-12}
@@ -761,6 +915,11 @@ def test_instance_norm():
     )
 
 
+def test_mean_variance_norm():
+    verify_unary("MeanVarianceNormalization", [1, 3, 32, 32])
+    verify_unary("MeanVarianceNormalization", [1, 3, 32, 32], attrs={"axes": 
(1, 2, 3)})
+
+
 def test_layer_norm():
     layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"], 
["d"], epsilon=1e-12)
 
@@ -1075,9 +1234,36 @@ def test_arg_min_max(in_dtype, axis, keepdims):
     verify_arg_min_max([3, 4, 4], in_dtype, "ArgMin", axis, keepdims)
 
 
[email protected]("axis", [-1, 0, 1])
[email protected]("largest", [True, False])
+def test_topk(axis: int, largest: int):
+    in_shape = [32, 32, 32]
+    k_value = 4
+    out_shape = in_shape
+    out_shape[axis] = k_value
+    k = make_constant_node("k", TensorProto.INT64, [1], [k_value])
+    node = onnx.helper.make_node(
+        "TopK",
+        inputs=["data", "k"],
+        outputs=["values", "indices"],
+        axis=axis,
+        largest=largest,
+    )
+    graph = helper.make_graph(
+        [k, node],
+        "topk_test",
+        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, 
in_shape)],
+        outputs=[
+            helper.make_tensor_value_info("values", TensorProto.FLOAT, 
out_shape),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, 
out_shape),
+        ],
+    )
+    model = helper.make_model(graph, producer_name="topk_test")
+
+    check_correctness(model)
+
+
 @pytest.mark.parametrize("dynamic", [False, True])
-# TODO(jwfromm) Current approach to dynamic expand is technically not well 
formed. Reenable once fixed.
[email protected]("Produces ill-formed IR")
 def test_expand(dynamic):
     if dynamic:
         # TODO: Support dynamic shape for Expand
@@ -1586,14 +1772,6 @@ def test_range():
     check_correctness(model)
 
 
-def test_less():
-    verify_compare("Less", [32, 32])
-
-
-def test_less_equal():
-    verify_compare("LessOrEqual", [32, 32])
-
-
 def test_batch_norm():
     batch_norm_node = helper.make_node(
         "BatchNormalization", ["x", "s", "bias", "mean", "var"], ["y"], 
epsilon=1e-2
@@ -1811,17 +1989,58 @@ def test_global_average_pool():
     verify_unary("GlobalAveragePool", [1, 3, 32, 32, 32])
 
 
+def test_global_max_pool():
+    verify_unary("GlobalMaxPool", [1, 3, 32])
+    verify_unary("GlobalMaxPool", [1, 3, 32, 32])
+    verify_unary("GlobalMaxPool", [1, 3, 32, 32, 32])
+
+
[email protected]("p", [1, 2, 3])
+def test_global_lp_pool(p: int):
+    verify_unary("GlobalLpPool", [1, 3, 32], attrs={"p": p})
+    verify_unary("GlobalLpPool", [1, 3, 32, 32], attrs={"p": p})
+    verify_unary("GlobalLpPool", [1, 3, 32, 32, 32], attrs={"p": p})
+
+
[email protected]("kernel_shape", [[2, 2], [3, 3]])
[email protected]("pads", [None, [1, 1, 1, 1]])
[email protected]("strides", [None, [2, 2]])
+def test_maxunpool(kernel_shape, pads, strides):
+    input_shape = [16, 3, 16, 16]
+    input_names = ["X", "I"]
+    input_info = [
+        helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape),
+        helper.make_tensor_value_info("I", TensorProto.INT64, input_shape),
+    ]
+
+    attrs = {"kernel_shape": kernel_shape}
+    if pads is not None:
+        attrs["pads"] = pads
+    if strides is not None:
+        attrs["strides"] = strides
+
+    node = helper.make_node("MaxUnpool", inputs=input_names, outputs=["y"], 
**attrs)
+
+    graph = helper.make_graph(
+        [node],
+        "maxunpool_test",
+        inputs=input_info,
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
+    )
+
+    max_random = int(np.prod(np.array(kernel_shape)))
+    indices = np.random.randint(0, max_random, size=input_shape)
+
+    model = helper.make_model(graph, producer_name="maxunpool_test")
+    check_correctness(model, inputs={"I": indices})
+
+
 def test_flatten():
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0})
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1})
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2})
 
 
-def test_greater():
-    verify_compare("Greater", [32, 32])
-    verify_compare("Greater", [64, 16])
-
-
 def test_onehot():
     one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"], 
["y"], axis=1)
     graph = helper.make_graph(
@@ -1844,8 +2063,189 @@ def test_onehot():
     check_correctness(model, inputs=values)
 
 
-def test_reciprocal():
-    verify_unary("Reciprocal", [3, 32, 32])
[email protected]("axis", [None, 0, 1, -1])
[email protected]("sorted", [0, 1])
+def test_unique(axis: Optional[int], sorted: int):
+    input_shape = [32, 32]
+    if axis is None:
+        output_shape = [-1]
+    else:
+        output_shape = [32, 32]
+        output_shape[axis] = -1
+    unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis, 
sorted=sorted)
+    graph = helper.make_graph(
+        [unique_node],
+        "unique_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
input_shape)],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, 
output_shape)],
+    )
+    model = helper.make_model(graph, producer_name="unique_test")
+    check_correctness(model)
+
+
[email protected]("mode", ["DCR", "CRD"])
+def test_depth_to_space(mode: Literal["DCR", "CRD"]):
+    in_shape = [1, 8, 2, 3]
+    out_shape = [1, 2, 4, 6]
+    blocksize = 2
+    node = onnx.helper.make_node(
+        "DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blocksize, 
mode=mode
+    )
+    graph = helper.make_graph(
+        [node],
+        "depth_to_space_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
in_shape)],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, 
out_shape)],
+    )
+    model = helper.make_model(graph, producer_name="depth_to_space_test")
+
+    check_correctness(model)
+
+
+def test_space_to_depth():
+    in_shape = [1, 2, 4, 6]
+    out_shape = [1, 8, 2, 3]
+    blocksize = 2
+    node = onnx.helper.make_node("SpaceToDepth", inputs=["x"], outputs=["y"], 
blocksize=blocksize)
+    graph = helper.make_graph(
+        [node],
+        "space_to_depth_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
in_shape)],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, 
out_shape)],
+    )
+    model = helper.make_model(graph, producer_name="space_to_depth_test")
+
+    check_correctness(model)
+
+
+def construct_sequence(input_shape: List[int], num_tensors: int, name: str = 
"sequence"):
+    inputs = [f"data{i}" for i in range(num_tensors)]
+    sequence_construct_node = helper.make_node("SequenceConstruct", inputs, 
[name])
+    graph_inputs = [
+        helper.make_tensor_value_info(f"data{i}", TensorProto.FLOAT, 
input_shape)
+        for i in range(num_tensors)
+    ]
+    return sequence_construct_node, graph_inputs
+
+
+def make_constant_node(name: str, data_type: int, dims: List[int], vals: 
List[int]):
+    return helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=[name],
+        value=helper.make_tensor(name=name, data_type=data_type, dims=dims, 
vals=vals),
+    )
+
+
+def test_sequence_construct():
+    node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=2)
+    graph = helper.make_graph(
+        [node],
+        "test_sequence_construct",
+        inputs=graph_inputs,
+        outputs=[helper.make_tensor_sequence_value_info("sequence", 
TensorProto.FLOAT, [32, 32])],
+    )
+    model = helper.make_model(graph, producer_name="test_sequence_construct")
+    check_correctness(model)
+
+
+def test_sequence_empty():
+    sequence_empty_node = helper.make_node("SequenceEmpty", [], ["sequence"])
+    graph = helper.make_graph(
+        [sequence_empty_node],
+        "test_sequence_empty",
+        inputs=[],
+        outputs=[helper.make_tensor_sequence_value_info("sequence", 
TensorProto.FLOAT, [])],
+    )
+    model = helper.make_model(graph, producer_name="test_sequence_empty")
+    check_correctness(model)
+
+
[email protected]("explicit_position", [True, False])
+def test_sequence_erase(explicit_position: bool):
+    seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=4)
+    index = make_constant_node("index", TensorProto.INT64, (), [1])
+    node_input = ["sequence", "index"] if explicit_position else ["sequence"]
+    sequence_erase_node = helper.make_node("SequenceErase", node_input, 
["output"])
+    graph = helper.make_graph(
+        [index, seq_node, sequence_erase_node],
+        "test_sequence_erase",
+        inputs=graph_inputs,
+        outputs=[helper.make_tensor_sequence_value_info("output", 
TensorProto.FLOAT, [32, 32])],
+    )
+    model = helper.make_model(graph, producer_name="test_sequence_erase")
+    check_correctness(model)
+
+
[email protected]("explicit_position", [True, False])
+def test_sequence_insert(explicit_position: bool):
+    seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=4)
+    index = make_constant_node("index", TensorProto.INT64, (), [0])
+    node_input = ["sequence", "value", "index"] if explicit_position else 
["sequence", "value"]
+    sequence_insert_node = helper.make_node("SequenceInsert", node_input, 
["output"])
+    graph = helper.make_graph(
+        [index, seq_node, sequence_insert_node],
+        "test_sequence_insert",
+        inputs=[*graph_inputs, helper.make_tensor_value_info("value", 
TensorProto.FLOAT, [32, 32])],
+        outputs=[helper.make_tensor_sequence_value_info("output", 
TensorProto.FLOAT, [32, 32])],
+    )
+    model = helper.make_model(graph, producer_name="test_sequence_insert")
+    check_correctness(model)
+
+
[email protected]("new_axis", [0, 1])
+def test_concat_from_sequence(new_axis: Literal[0, 1]):
+    if new_axis == 1:
+        pytest.skip("ConcatFromSequence with new_axis=1 is not supported yet")
+    seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=2)
+    concat_from_sequence_node = helper.make_node(
+        "ConcatFromSequence", ["sequence"], ["output"], axis=1
+    )
+    graph = helper.make_graph(
+        [seq_node, concat_from_sequence_node],
+        "test_concat_from_sequence",
+        inputs=graph_inputs,
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
[64, 32])],
+    )
+    model = helper.make_model(graph, producer_name="test_concat_from_sequence")
+    check_correctness(model)
+
+
[email protected]("split", [2, [16, 48]])
+def test_split_to_sequence(split):
+    split_to_sequence_node = helper.make_node(
+        "SplitToSequence",
+        ["data", "split"],
+        ["output"],
+        axis=0,
+    )
+    split_shape = [len(split)] if isinstance(split, list) else ()
+    split_node = make_constant_node(
+        "split", TensorProto.INT64, split_shape, [split] if isinstance(split, 
int) else split
+    )
+    graph = helper.make_graph(
+        [split_node, split_to_sequence_node],
+        "test_split_to_sequence",
+        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64, 
32])],
+        outputs=[helper.make_tensor_sequence_value_info("output", 
TensorProto.FLOAT, [32, 32])],
+    )
+    model = helper.make_model(graph, producer_name="test_split_to_sequence")
+    check_correctness(model)
+
+
+def test_sequence_at():
+    seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=4)
+    index = make_constant_node("index", TensorProto.INT64, (), [1])
+    node_input = ["sequence", "index"]
+    sequence_at_node = helper.make_node("SequenceAt", node_input, ["output"])
+    graph = helper.make_graph(
+        [index, seq_node, sequence_at_node],
+        "test_sequence_at",
+        inputs=graph_inputs,
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
[32, 32])],
+    )
+    model = helper.make_model(graph, producer_name="test_sequence_at")
+    check_correctness(model)
 
 
 def test_symbolic_shape_deduction():
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
index fcb8727d85..a80b988d06 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -60,7 +60,7 @@ def test_unique(exec_mode):
     result, result_sorted = run_cpu(InputModule, "foo", data, 
exec_mode=exec_mode)
 
     expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
-    expected_output = [data_numpy.flatten()[index] for index in 
sorted(indices, reverse=True)]
+    expected_output = [data_numpy.flatten()[index] for index in 
sorted(indices)]
 
     np.testing.assert_array_equal(expected_output_sorted, 
result_sorted.numpy())
     np.testing.assert_array_equal(expected_output, result.numpy())
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index d03d48968d..12436cf802 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -204,6 +204,53 @@ def test_conv1d_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_conv1d_transpose():
+    # fmt: off
+    @I.ir_module
+    class Conv1dTranspose:
+        @R.function
+        def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((128, 16, 
3), "float32")):
+            gv = R.nn.conv1d_transpose(x, w, strides=2, padding=1, dilation=1, 
output_padding=1, groups=8)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128), 
T.int64(28)), "float32"), w: T.Buffer((T.int64(128), T.int64(16), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), 
T.int64(55)))
+            data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58)))
+            kernel = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3)))
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(55)):
+                with T.block("data_dilate"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    data_dilate[v_i0, v_i1, v_i2] = T.if_then_else(v_i2 % 
T.int64(2) == T.int64(0), x[v_i0, v_i1, v_i2 // T.int64(2)], T.float32(0.0))
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(58)):
+                with T.block("data_pad"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    data_pad[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= 
v_i2 and v_i2 < T.int64(56), data_dilate[v_i0, v_i1, v_i2 - T.int64(1)], 
T.float32(0.0))
+            for o, i, w_1 in T.grid(T.int64(16), T.int64(128), T.int64(3)):
+                with T.block("kernel"):
+                    v_o, v_i, v_w = T.axis.remap("SSS", [o, i, w_1])
+                    kernel[v_o, v_i, v_w] = w[v_i, v_o, T.int64(2) - v_w]
+            for b, c, w_1, dc, dw in T.grid(T.int64(2), T.int64(128), 
T.int64(56), T.int64(16), T.int64(3)):
+                with T.block("compute"):
+                    v_b, v_c, v_w, v_dc, v_dw = T.axis.remap("SSSRR", [b, c, 
w_1, dc, dw])
+                    with T.init():
+                        compute[v_b, v_c, v_w] = T.float32(0.0)
+                    compute[v_b, v_c, v_w] = compute[v_b, v_c, v_w] + 
data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_w + v_dw] * kernel[v_c 
% T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dw]
+
+        @R.function
+        def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((128, 
16, 3), dtype="float32")) -> R.Tensor((2, 128, 56), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(cls.conv1d_transpose, (x, w), 
out_sinfo=R.Tensor((2, 128, 56), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Conv1dTranspose)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_conv2d():
     # fmt: off
     @tvm.script.ir_module


Reply via email to