This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9fdb86d3f6 [Relax][ONNX] Expand op support for ONNX frontend (#17427)
9fdb86d3f6 is described below
commit 9fdb86d3f6bccc41a772328b5b0442908bc9f9a9
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Oct 3 22:36:55 2024 +0800
[Relax][ONNX] Expand op support for ONNX frontend (#17427)
* [Relax][ONNX] Expand op support for ONNX frontend
This PR adds a variety of ONNX ops to the Relax frontend, including:
- Acos
- Acosh
- And
- Asin
- Asinh
- Atan
- Atanh
- BitwiseAnd
- BitwiseOr
- BitwiseXor
- Ceil
- ConcatFromSequence
- ConvTranspose
- Cosh
- DepthToSpace
- FastGelu
- Floor
- GlobalLpPool
- GlobalMaxPool
- GreaterOrEqual
- IsInf
- IsNaN
- LeakyRelu
- LogSoftmax
- MaxUnpool
- Mean
- MeanVarianceNormalization
- Mish
- Or
- PRelu
- Round
- Scatter
- ScatterElements
- Selu
- SequenceAt
- SequenceConstruct
- SequenceEmpty
- SequenceErase
- SequenceInsert
- SequenceLength
- Shrink
- Sinh
- Size
- Softplus
- Softsign
- SpaceToDepth
- SplitToSequence
- Tan
- ThresholdedRelu
- TopK
- Unique
- Xor
A few ops are still unsupported; see the commented-out ops in the ONNX frontend.
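For reference, a minimal, untested sketch of driving one of the newly supported ops
through the Relax ONNX importer (the op choice, shapes, and names below are purely
illustrative):

    from onnx import TensorProto, helper
    from tvm.relax.frontend.onnx import from_onnx

    # Build a one-node ONNX graph that uses one of the newly supported ops.
    node = helper.make_node("LeakyRelu", ["x"], ["y"], alpha=0.1)
    graph = helper.make_graph(
        [node],
        "leaky_relu_example",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 32])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 32])],
    )
    model = helper.make_model(graph, producer_name="leaky_relu_example")
    # Import into Relax; from_onnx now returns just the IRModule.
    mod = from_onnx(model, opset=14)
    mod.show()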
* lint
* lint
* lint
* update for ci
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 1302 ++++++++++++++++----
python/tvm/relax/op/set.py | 8 +-
python/tvm/relax/transform/legalize_ops/nn.py | 9 +-
tests/python/relax/test_frontend_onnx.py | 664 ++++++++--
tests/python/relax/test_relax_operators.py | 2 +-
.../python/relax/test_transform_legalize_ops_nn.py | 47 +
6 files changed, 1617 insertions(+), 415 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 462d1cf92c..5777f51fe2 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -34,14 +34,15 @@ If this fails, there may still be dynamic operations in the model.
Not all TVM kernels currently support dynamic shapes, please file an issue on
github.com/apache/tvm/issues if you hit an error with dynamic kernels.
"""
+import math
import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as _np
import onnx.onnx_ml_pb2
import tvm
-from tvm import relax, tir, topi
+from tvm import TVMError, relax, tir, topi
from tvm.ir import IRModule
from tvm.ir.supply import NameSupply
from tvm.tir.generic import cast
@@ -236,28 +237,176 @@ class MatMul(OnnxOpConverter):
return relax.op.matmul(inputs[0], inputs[1])
-class Div(OnnxOpConverter):
- """Converts an onnx Div node into an equivalent Relax expression."""
+class BinaryBase(OnnxOpConverter):
+ """Converts an onnx BinaryBase node into an equivalent Relax expression."""
+
+ numpy_op: Callable = None
+ relax_op: Callable = None
@classmethod
- def _impl_v14(cls, bb, inputs, attr, params):
+ def _impl_v1(cls, bb, inputs, attr, params):
+ if cls.numpy_op is None or cls.relax_op is None:
+ raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = inputs[0].data.numpy() / inputs[1].data.numpy()
+ output = cls.numpy_op( # pylint: disable=not-callable
+ inputs[0].data.numpy(), inputs[1].data.numpy()
+ )
return relax.const(output, inputs[0].struct_info.dtype)
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
x = (
- int(inputs[0].value)
+ _np.array(inputs[0].value)
if isinstance(inputs[0], relax.PrimValue)
else inputs[0].data.numpy()
)
y = (
- int(inputs[1].value)
+ _np.array(inputs[1].value)
if isinstance(inputs[1], relax.PrimValue)
else inputs[1].data.numpy()
)
- return relax.PrimValue(int(x / y))
+ return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable
+
+ return cls.relax_op(inputs[0], inputs[1]) # pylint: disable=not-callable
+
+
+class Add(BinaryBase):
+ """Converts an onnx Add node into an equivalent Relax expression."""
+
+ numpy_op = _np.add
+ relax_op = relax.op.add
+
+
+class Sub(BinaryBase):
+ """Converts an onnx Sub node into an equivalent Relax expression."""
+
+ numpy_op = _np.subtract
+ relax_op = relax.op.subtract
+
+
+class Mul(BinaryBase):
+ """Converts an onnx Mul node into an equivalent Relax expression."""
+
+ numpy_op = _np.multiply
+ relax_op = relax.op.multiply
+
+
+class Div(BinaryBase):
+ """Converts an onnx Div node into an equivalent Relax expression."""
+
+ numpy_op = _np.divide
+ relax_op = relax.op.divide
+
+
+class Pow(BinaryBase):
+ """Converts an onnx Pow node into an equivalent Relax expression."""
+
+ numpy_op = _np.power
+ relax_op = relax.op.power
+
+
+class And(BinaryBase):
+ """Converts an onnx And node into an equivalent Relax expression."""
+
+ numpy_op = _np.logical_and
+ relax_op = relax.op.logical_and
- return relax.op.divide(inputs[0], inputs[1])
+
+class Or(BinaryBase):
+ """Converts an onnx Or node into an equivalent Relax expression."""
+
+ numpy_op = _np.logical_or
+ relax_op = relax.op.logical_or
+
+
+class Xor(BinaryBase):
+ """Converts an onnx Xor node into an equivalent Relax expression."""
+
+ numpy_op = _np.logical_xor
+ relax_op = relax.op.logical_xor
+
+
+class Less(BinaryBase):
+ """Converts an onnx Less node into an equivalent Relax expression."""
+
+ numpy_op = _np.less
+ relax_op = relax.op.less
+
+
+class LessOrEqual(BinaryBase):
+ """Converts an onnx LessEqual node into an equivalent Relax expression."""
+
+ numpy_op = _np.less_equal
+ relax_op = relax.op.less_equal
+
+
+class Greater(BinaryBase):
+ """Converts an onnx Greater node into an equivalent Relax expression."""
+
+ numpy_op = _np.greater
+ relax_op = relax.op.greater
+
+
+class GreaterOrEqual(BinaryBase):
+ """Converts an onnx GreaterEqual node into an equivalent Relax expression."""
+
+ numpy_op = _np.greater_equal
+ relax_op = relax.op.greater_equal
+
+
+class Equal(OnnxOpConverter):
+ """Converts an onnx Equal node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ if all([isinstance(inp, relax.Constant) for inp in inputs]):
+ output = inputs[0].data.numpy() == inputs[1].data.numpy()
+ return relax.const(output, output.dtype)
+ elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp in inputs]):
+ lhs = get_prim_expr_list(inputs[0])
+ rhs = get_prim_expr_list(inputs[1])
+ if len(lhs) != len(rhs):
+ raise ValueError("Cannot compare two tensors with different shapes")
+ output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)]
+ return relax.const(output, "bool")
+ return relax.op.equal(inputs[0], inputs[1])
+
+
+class BitwiseBase(BinaryBase):
+ """Converts an onnx BitwiseBase node into an equivalent Relax expression."""
+
+ @classmethod
+ def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
+ valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
+ for num, inp in enumerate(inputs):
+ if inp.struct_info.dtype not in valid_types:
+ raise ValueError(
+ f"Bitwise operations expect all inputs to have integer types, "
+ f"got {inp.struct_info.dtype} for input {num}"
+ )
+ return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op)
+
+
+class BitwiseAnd(BitwiseBase):
+ """Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v18(cls, bb, inputs, attr, params):
+ return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and)
+
+
+class BitwiseOr(BitwiseBase):
+ """Converts an onnx BitwiseOr node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v18(cls, bb, inputs, attr, params):
+ return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or)
+
+
+class BitwiseXor(BitwiseBase):
+ """Converts an onnx BitwiseXor node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v18(cls, bb, inputs, attr, params):
+ return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor)
class Sigmoid(OnnxOpConverter):
@@ -277,6 +426,15 @@ class Softmax(OnnxOpConverter):
return relax.op.nn.softmax(inputs[0], axis=axis)
+class LogSoftmax(OnnxOpConverter):
+ """Converts an onnx LogSoftmax node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", -1)
+ return relax.op.nn.log_softmax(inputs[0], axis=axis)
+
+
class Transpose(OnnxOpConverter):
"""Converts an onnx Transpose node into an equivalent Relax expression."""
@@ -375,67 +533,6 @@ class Concat(OnnxOpConverter):
return relax.op.concat(inputs, axis=axis)
-class Add(OnnxOpConverter):
- """Convert an onnx Add node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = inputs[0].data.numpy() + inputs[1].data.numpy()
- return relax.const(output, output.dtype)
- # If primvalues are involved, handle them directly.
- if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
- x = (
- int(inputs[0].value)
- if isinstance(inputs[0], relax.PrimValue)
- else inputs[0].data.numpy()
- )
- y = (
- int(inputs[1].value)
- if isinstance(inputs[1], relax.PrimValue)
- else inputs[1].data.numpy()
- )
- return relax.PrimValue(int(x + y))
- return relax.op.add(inputs[0], inputs[1])
-
-
-class Sum(OnnxOpConverter):
- """Convert an onnx Sum node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
- for in_index in range(len(inputs) - 1):
- inputs[in_index + 1] = relax.op.add(inputs[in_index], inputs[in_index + 1])
-
- return inputs[len(inputs) - 1]
-
-
-class Mul(OnnxOpConverter):
- """Convert an onnx Mul node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- # When all inputs are constant, directly multiply.
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = inputs[0].data.numpy() * inputs[1].data.numpy()
- return relax.const(output, output.dtype)
- # If primvalues are involved, handle them directly.
- if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
- x = (
- int(inputs[0].value)
- if isinstance(inputs[0], relax.PrimValue)
- else inputs[0].data.numpy()
- )
- y = (
- int(inputs[1].value)
- if isinstance(inputs[1], relax.PrimValue)
- else inputs[1].data.numpy()
- )
- return relax.PrimValue(int(x * y))
-
- return relax.op.multiply(inputs[0], inputs[1])
-
-
class Cast(OnnxOpConverter):
"""Convert an onnx Cast node into an equivalent Relax expression."""
@@ -482,8 +579,38 @@ class Gather(OnnxOpConverter):
shape_val = data[np_index]
return relax.PrimValue(shape_val)
- # TODO(jwfromm) Make relax.take work with other indices shape.
- return bb.emit_te(topi.take, data, indices, axis)
+ return relax.op.take(data, indices, axis)
+
+
+class Scatter(OnnxOpConverter):
+ """Convert an onnx Scatter node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 0)
+ return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ raise ValueError("Scatter is deprecated in ONNX 11")
+
+
+class ScatterElements(OnnxOpConverter):
+ """Convert an onnx ScatterElements node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 0)
+ return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)
+
+
+class Size(OnnxOpConverter):
+ """Convert an onnx Size node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ # TODO(tvm-team): add native support for size op
+ return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
class Gemm(OnnxOpConverter):
@@ -542,29 +669,6 @@ class Reshape(OnnxOpConverter):
return out
-class Gelu(OnnxOpConverter):
- """Operator converter for Gelu from Microsoft onnxruntime contrib opset.
-
- gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
- """
-
- @classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
- return relax.op.nn.gelu(inputs[0])
-
-
-class BiasGelu(OnnxOpConverter):
- """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
-
- bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
- """
-
- @classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
- inp = relax.op.add(inputs[0], inputs[1])
- return relax.op.nn.gelu(inp)
-
-
class Where(OnnxOpConverter):
"""Convert an onnx Where node into an equivalent Relax expression."""
@@ -605,24 +709,6 @@ class Clip(OnnxOpConverter):
return results
-class Equal(OnnxOpConverter):
- """Converts an onnx Equal node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = inputs[0].data.numpy() == inputs[1].data.numpy()
- return relax.const(output, output.dtype)
- elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp in inputs]):
- lhs = get_prim_expr_list(inputs[0])
- rhs = get_prim_expr_list(inputs[1])
- if len(lhs) != len(rhs):
- raise ValueError("Cannot compare two tensors with different shapes")
- output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)]
- return relax.const(output, "bool")
- return relax.op.equal(inputs[0], inputs[1])
-
-
class Shape(OnnxOpConverter):
"""Converts an onnx Equal node into an equivalent Relax expression."""
@@ -643,22 +729,6 @@ class Shape(OnnxOpConverter):
return data_info.shape
-class Tanh(OnnxOpConverter):
- """Converts an onnx Tanh node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- return relax.op.tanh(inputs[0])
-
-
-class Sqrt(OnnxOpConverter):
- """Converts an onnx Sqrt node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- return relax.op.sqrt(inputs[0])
-
-
class Trilu(OnnxOpConverter):
"""Given a 2-D matrix or batches of 2-D matrices, returns the upper or
lower triangular part of the tensor(s)
@@ -691,12 +761,157 @@ class Relu(OnnxOpConverter):
return relax.op.nn.relu(inputs[0])
-class Pow(OnnxOpConverter):
- """Converts an onnx Pow node into an equivalent Relax expression."""
+class Elu(OnnxOpConverter):
+ """Converts an onnx Elu node into an equivalent Relax expression."""
@classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- return relax.op.power(inputs[0], inputs[1])
+ def _impl_v1(cls, bb, inputs, attr, params):
+ alpha = float(attr.get("alpha", 1.0))
+ return relax.expr.const(-alpha) * relax.op.nn.relu(
+ relax.expr.const(1.0) - relax.op.exp(inputs[0])
+ ) + relax.op.nn.relu(inputs[0])
+
+
+class Selu(OnnxOpConverter):
+ """Converts an onnx Selu node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ alpha = attr.get("alpha", 1.67326319217681884765625)
+ gamma = attr.get("gamma", 1.05070102214813232421875)
+ return relax.const(gamma) * (
+ relax.const(-alpha) * relax.op.nn.relu(relax.const(1.0) - relax.op.exp(inputs[0]))
+ + relax.op.nn.relu(inputs[0])
+ )
+
+
+class Mish(OnnxOpConverter):
+ """Converts an onnx Mish node into an equivalent Relax expression.
+
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
+ """
+
+ @classmethod
+ def _impl_v18(cls, bb, inputs, attr, params):
+ dtype = inputs[0].checked_type.dtype
+ return inputs[0] * relax.op.tanh(
+ relax.op.log(relax.const(1.0, dtype) + relax.op.exp(inputs[0]))
+ )
+
+
+class PRelu(OnnxOpConverter):
+ """Converts an onnx PRelu node into an equivalent Relax expression.
+
+ f(x) = slope * x for x < 0, x for x >= 0
+ """
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ x = inputs[0]
+ slope = inputs[1]
+ # TODO(tvm-team): Should add a new op for this.
+ return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope)
+
+
+class ThresholdedRelu(OnnxOpConverter):
+ """Converts an onnx ThresholdedRelu node into an equivalent Relax expression.
+
+ f(x) = x for x > alpha, 0 otherwise
+ """
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ x = inputs[0]
+ alpha = attr.get("alpha", 1.0)
+ return relax.op.greater(x, relax.const(alpha)).astype("float32") * x
+
+
+class LeakyRelu(OnnxOpConverter):
+ """Converts an onnx LeakyRelu node into an equivalent Relax expression.
+
+ f(x) = x for x > 0, alpha * x otherwise
+ """
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ x = inputs[0]
+ alpha = attr.get("alpha", 0.01)
+ return relax.op.nn.leakyrelu(x, alpha)
+
+
+class Gelu(OnnxOpConverter):
+ """Operator converter for Gelu from Microsoft onnxruntime contrib opset.
+
+ gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
+ """
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ return relax.op.nn.gelu(inputs[0])
+
+
+class FastGelu(OnnxOpConverter):
+ """Operator converter for FastGelu from Microsoft onnxruntime contrib opset.
+
+ fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
+ = 0.5x(1 + tanh((sqrt(2/pi)x + 0.044715(sqrt(2/pi)x^3)))
+ = 0.5x(1 + tanh(c1 * x + c2 * x^3)))
+ , where
+ c1 = sqrt(2/pi)
+ c2 = 0.044715 * sqrt(2/pi)
+ """
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ x = inputs[0]
+ if inputs[1]:
+ bias = inputs[1]
+ bias_shape = bias.struct_info.shape
+ assert len(bias_shape) == 1, "bias term must be a 1D tensor"
+ x += bias
+
+ # Declare consts
+ const_dtype = x.struct_info.dtype
+ half = relax.const(0.5, dtype=const_dtype)
+ one = relax.const(1.0, dtype=const_dtype)
+ const1 = relax.const(math.sqrt(2 / math.pi), dtype=const_dtype)
+ const2 = relax.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)
+
+ # Compute FastGelu
+ term1 = relax.op.multiply(half, x)
+ term2 = relax.op.multiply(const1, x)
+ term3 = relax.op.multiply(const2, relax.op.power(x, relax.const(3, const_dtype)))
+ tanh = relax.op.tanh(relax.op.add(term2, term3))
+ return relax.op.multiply(term1, relax.op.add(one, tanh))
+
+
+class BiasGelu(OnnxOpConverter):
+ """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
+
+ bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
+ """
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ inp = relax.op.add(inputs[0], inputs[1])
+ return relax.op.nn.gelu(inp)
+
+
+class Shrink(OnnxOpConverter):
+ """Converts an onnx Shrink node into an equivalent Relax expression.
+
+ f(x) = x + bias if x > lambd, x - bias if x < -lambd, 0 otherwise
+ """
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ x = inputs[0]
+ dtype = x.struct_info.dtype
+ lambd = relax.const(attr.get("lambd", 0.5), dtype)
+ bias = relax.const(attr.get("bias", 0.0), dtype)
+ zeros = relax.op.zeros_like(x)
+ return relax.op.where(x > lambd, x - bias, zeros) + relax.op.where(
+ x < -lambd, x + bias, zeros
+ )
class Conv(OnnxOpConverter):
@@ -730,21 +945,55 @@ class Conv(OnnxOpConverter):
weight=inputs[1],
strides=attr.get("strides", 1),
padding=attr.get("pads", 0),
- dilation=attr.get("dilation", 1),
+ dilation=attr.get("dilations", 1),
groups=attr.get("group", 1),
data_layout=data_layout,
kernel_layout=kernel_layout,
)
)
if inputs[2] is not None:
- bias = relax.op.reshape(
- inputs[2],
- [1, -1]
- + [
- 1,
- ]
- * (ndim - 2),
- )
+ bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2))
+ conv_out = relax.op.add(conv_out, bias)
+
+ return conv_out
+
+
+class ConvTranspose(OnnxOpConverter):
+ """Converts an onnx ConvTranspose node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ if hasattr(inputs[0].struct_info, "ndim"):
+ ndim = inputs[0].struct_info.ndim
+ else:
+ ndim = len(inputs[0].struct_info.shape)
+
+ if ndim == 3:
+ op = relax.op.nn.conv1d_transpose
+ data_layout = "NCW"
+ kernel_layout = "IOW"
+ elif ndim == 4:
+ op = relax.op.nn.conv2d_transpose
+ data_layout = "NCHW"
+ kernel_layout = "IOHW"
+ elif ndim == 5:
+ raise NotImplementedError("Relax ConvTranspose3d not supported yet")
+ else:
+ raise NotImplementedError("Ndim > 5 not supported for convolution.")
+
+ conv_out = op(
+ data=inputs[0],
+ weight=inputs[1],
+ strides=attr.get("strides", 1),
+ padding=attr.get("pads", 0),
+ dilation=attr.get("dilations", 1),
+ groups=attr.get("group", 1),
+ data_layout=data_layout,
+ kernel_layout=kernel_layout,
+ )
+
+ if inputs[2] is not None:
+ bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2))
conv_out = relax.op.add(conv_out, bias)
return conv_out
@@ -839,17 +1088,6 @@ class ConstantOfShape(OnnxOpConverter):
return relax.op.broadcast_to(relax.const(value, dtype), shape)
-class Sub(OnnxOpConverter):
- """Converts an onnx Sub node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = inputs[0].data.numpy() - inputs[1].data.numpy()
- return relax.const(output, output.dtype)
- return relax.op.subtract(inputs[0], inputs[1])
-
-
class Sin(OnnxOpConverter):
"""Converts an onnx Sin node into an equivalent Relax expression."""
@@ -858,6 +1096,14 @@ class Sin(OnnxOpConverter):
return relax.op.sin(inputs[0])
+class Sinh(OnnxOpConverter):
+ """Converts an onnx Sinh node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ return relax.op.sinh(inputs[0])
+
+
class Cos(OnnxOpConverter):
"""Converts an onnx Cos node into an equivalent Relax expression."""
@@ -866,6 +1112,78 @@ class Cos(OnnxOpConverter):
return relax.op.cos(inputs[0])
+class Cosh(OnnxOpConverter):
+ """Converts an onnx Cosh node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ return relax.op.cosh(inputs[0])
+
+
+class Tan(OnnxOpConverter):
+ """Converts an onnx Tan node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v7(cls, bb, inputs, attr, params):
+ return relax.op.tan(inputs[0])
+
+
+class Tanh(OnnxOpConverter):
+ """Converts an onnx Tanh node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v7(cls, bb, inputs, attr, params):
+ return relax.op.tanh(inputs[0])
+
+
+class Acos(OnnxOpConverter):
+ """Converts an onnx Acos node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v7(cls, bb, inputs, attr, params):
+ return relax.op.acos(inputs[0])
+
+
+class Acosh(OnnxOpConverter):
+ """Converts an onnx Acosh node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ return relax.op.acosh(inputs[0])
+
+
+class Asin(OnnxOpConverter):
+ """Converts an onnx Asin node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v7(cls, bb, inputs, attr, params):
+ return relax.op.asin(inputs[0])
+
+
+class Asinh(OnnxOpConverter):
+ """Converts an onnx Asinh node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ return relax.op.asinh(inputs[0])
+
+
+class Atan(OnnxOpConverter):
+ """Converts an onnx Atan node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v7(cls, bb, inputs, attr, params):
+ return relax.op.atan(inputs[0])
+
+
+class Atanh(OnnxOpConverter):
+ """Converts an onnx Atanh node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ return relax.op.atanh(inputs[0])
+
+
class Neg(OnnxOpConverter):
"""Converts an onnx Neg node into an equivalent Relax expression."""
@@ -877,47 +1195,121 @@ class Neg(OnnxOpConverter):
return relax.op.negative(inputs[0])
-class Abs(OnnxOpConverter):
- """Converts an onnx Abs node into an equivalent Relax expression."""
+class Abs(OnnxOpConverter):
+ """Converts an onnx Abs node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ if isinstance(inputs[0], relax.Constant):
+ output = _np.abs(inputs[0].data.numpy())
+ return relax.const(output, output.dtype)
+ return relax.op.abs(inputs[0])
+
+
+class Reciprocal(OnnxOpConverter):
+ """Converts an onnx Reciprocal node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ input_dtype = inputs[0].struct_info.dtype
+ return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0])
+
+
+class Floor(OnnxOpConverter):
+ """Converts an onnx Floor node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ return relax.op.floor(inputs[0])
+
+
+class Ceil(OnnxOpConverter):
+ """Converts an onnx Ceil node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ return relax.op.ceil(inputs[0])
+
+
+class Round(OnnxOpConverter):
+ """Converts an onnx Round node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ return relax.op.round(inputs[0])
+
+
+class IsInf(OnnxOpConverter):
+ """Converts an onnx IsInf node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v10(cls, bb, inputs, attr, params):
+ return relax.op.isinf(inputs[0])
+
+
+class IsNaN(OnnxOpConverter):
+ """Converts an onnx IsNaN node into an equivalent Relax expression."""
@classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if isinstance(inputs[0], relax.Constant):
- output = _np.abs(inputs[0].data.numpy())
- return relax.const(output, output.dtype)
- return relax.op.abs(inputs[0])
+ def _impl_v9(cls, bb, inputs, attr, params):
+ return relax.op.isnan(inputs[0])
-class Min(OnnxOpConverter):
- """Converts an onnx Min node into an equivalent Relax expression."""
+class Sqrt(OnnxOpConverter):
+ """Converts an onnx Sqrt node into an equivalent Relax expression."""
@classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
+ def _impl_v1(cls, bb, inputs, attr, params):
+ return relax.op.sqrt(inputs[0])
+
+
+class MultiInputBase(OnnxOpConverter):
+ """Converts an onnx MultiInputBase node into an equivalent Relax expression."""
+
+ numpy_op: Callable = None
+ relax_op: Callable = None
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ if cls.numpy_op is None or cls.relax_op is None:
+ raise NotImplementedError("numpy_op and relax_op must be defined for MultiInputBase")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
np_inputs = [inp.data.numpy() for inp in inputs]
- output = _np.minimum(*np_inputs)
+ output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable
return relax.const(output, output.dtype)
# Expand inputs, stack them, then perform minimum over the new axis.
inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
stacked_tensor = relax.op.concat(inputs, axis=0)
- return relax.op.min(stacked_tensor, axis=0)
+ return cls.relax_op(stacked_tensor, axis=0) # pylint: disable=not-callable
+
+
+class Min(MultiInputBase):
+ """Converts an onnx Min node into an equivalent Relax expression."""
+
+ numpy_op = _np.min
+ relax_op = relax.op.min
-class Max(OnnxOpConverter):
+class Max(MultiInputBase):
"""Converts an onnx Max node into an equivalent Relax expression."""
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- np_inputs = [inp.data.numpy() for inp in inputs]
- output = _np.maximum(*np_inputs)
- return relax.const(output, output.dtype)
+ numpy_op = _np.max
+ relax_op = relax.op.max
- # Expand inputs, stack them, then perform maximum over the new axis.
- inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
- stacked_tensor = relax.op.concat(inputs, axis=0)
- return relax.op.max(stacked_tensor, axis=0)
+
+class Mean(MultiInputBase):
+ """Converts an onnx Mean node into an equivalent Relax expression."""
+
+ numpy_op = _np.mean
+ relax_op = relax.op.mean
+
+
+class Sum(MultiInputBase):
+ """Converts an onnx Sum node into an equivalent Relax expression."""
+
+ numpy_op = _np.sum
+ relax_op = relax.op.sum
class Log(OnnxOpConverter):
@@ -956,26 +1348,22 @@ class Exp(OnnxOpConverter):
return relax.op.exp(data)
-class Less(OnnxOpConverter):
- """Converts an onnx Less node into an equivalent Relax expression."""
+class Softplus(OnnxOpConverter):
+ """Converts an onnx Softplus node into an equivalent Relax expression."""
@classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = _np.less(inputs[0].data.numpy(), inputs[1].data.numpy())
- return relax.const(output, output.dtype)
- return relax.op.less(inputs[0], inputs[1])
+ def _impl_v1(cls, bb, inputs, attr, params):
+ dtype = inputs[0].struct_info.dtype
+ return relax.op.log(relax.op.exp(inputs[0]) + relax.const(1, dtype=dtype))
-class LessOrEqual(OnnxOpConverter):
- """Converts an onnx LessOrEqual node into an equivalent Relax expression."""
+class Softsign(OnnxOpConverter):
+ """Converts an onnx Softsign node into an equivalent Relax expression."""
@classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = _np.less_equal(inputs[0].data.numpy(), inputs[1].data.numpy())
- return relax.const(output, output.dtype)
- return relax.op.less_equal(inputs[0], inputs[1])
+ def _impl_v1(cls, bb, inputs, attr, params):
+ dtype = inputs[0].struct_info.dtype
+ return inputs[0] / (relax.op.abs(inputs[0]) + relax.const(1, dtype=dtype))
class Split(OnnxOpConverter):
@@ -1456,6 +1844,20 @@ class BatchNormalization(OnnxOpConverter):
)
+class MeanVarianceNormalization(OnnxOpConverter):
+ """Converts an onnx MeanVarianceNormalization node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis = attr.get("axes", (0, 2, 3))
+ data_mean = relax.op.mean(data, axis=axis, keepdims=True)
+ data_mean_squared = relax.op.power(data_mean, relax.const(2, dtype="float32"))
+ data_squared = relax.op.power(data, relax.const(2, dtype="float32"))
+ data_squared_mean = relax.op.mean(data_squared, axis=axis, keepdims=True)
+ return (data - data_mean) / relax.op.sqrt(data_squared_mean - data_mean_squared)
+
+
class Pool(OnnxOpConverter):
"""A helper class for pool op converters."""
@@ -1557,16 +1959,79 @@ class GlobalAveragePool(OnnxOpConverter):
@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
rank = len(inputs[0].struct_info.shape)
- if rank == 3:
- return relax.op.nn.adaptive_avg_pool1d(inputs[0], 1)
- elif rank == 4:
- return relax.op.nn.adaptive_avg_pool2d(inputs[0], 1)
- elif rank == 5:
- return relax.op.nn.adaptive_avg_pool3d(inputs[0], 1)
- raise NotImplementedError(
- "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD."
- % (rank - 2)
+ axes = list(range(2, rank))
+ return relax.op.mean(inputs[0], axis=axes, keepdims=True)
+
+
+class GlobalMaxPool(OnnxOpConverter):
+ """Converts an onnx GlobalMaxPool node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ rank = len(inputs[0].struct_info.shape)
+ axes = list(range(2, rank))
+ return relax.op.max(inputs[0], axis=axes, keepdims=True)
+
+
+class GlobalLpPool(OnnxOpConverter):
+ """Converts an onnx GlobalLpPool node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v2(cls, bb, inputs, attr, params):
+ p = attr.get("p", 2.0)
+ dtype = inputs[0].struct_info.dtype
+ rank = len(inputs[0].struct_info.shape)
+ axes = list(range(2, rank))
+ x_abs = relax.op.abs(inputs[0])
+ x_p = relax.op.power(x_abs, relax.const(p, dtype=dtype))
+ x_sum = relax.op.sum(x_p, axes, keepdims=True)
+ return relax.op.power(x_sum, relax.const(1.0 / p, dtype=dtype))
+
+
+class MaxUnpool(OnnxOpConverter):
+ """Converts an onnx MaxUnpool node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ indices = inputs[1]
+ output_shape = inputs[2]
+ kernel_shape = attr.get("kernel_shape")
+ pads = attr.get("pads", [0] * len(kernel_shape) * 2)
+ strides = attr.get("strides", [1] * len(kernel_shape))
+
+ multiplier = _np.concatenate([[1, 1], list(strides)])
+ shape = [v.value for v in data.struct_info.shape]
+ total_output_shape = multiplier * shape
+ # Add extra dimensions from kernel size and stride mismatch
+ total_output_shape += _np.concatenate([[0, 0], list(kernel_shape)], axis=0)
+ total_output_shape -= _np.concatenate([[0, 0], list(strides)], axis=0)
+
+ if output_shape is not None:
+ total_output_shape = output_shape
+
+ elif pads is not None:
+ # Get pads in the proper format for relay.
+ pads = _np.concatenate([[0, 0, 0, 0], list(pads)], axis=0)
+ pads = _np.reshape(pads, [-1, 2])
+ # Compute the total padding per axis.
+ total_pad = _np.sum(pads, axis=-1)
+ # Reversing maxpool means that padding actually makes our output smaller.
+ total_output_shape = total_output_shape - total_pad
+
+ # Create a tensor of zeros then scatter our data through it.
+ relax_shape = relax.ShapeExpr(total_output_shape.tolist())
+ zeros_tensor = bb.emit(relax.op.zeros(relax_shape, data.struct_info.dtype))
+ # We need to flatten all our tensors before scattering.
+ flat_tensor = relax.op.scatter_elements(
+ relax.op.reshape(zeros_tensor, [-1]),
+ relax.op.reshape(indices, [-1]),
+ relax.op.reshape(data, [-1]),
+ axis=0,
)
+ # Reshape our flattened data back to normal.
+ output = relax.op.reshape(flat_tensor, relax_shape)
+ return output
class Flatten(OnnxOpConverter):
@@ -1799,6 +2264,32 @@ class ArgMin(OnnxOpConverter):
return relax.op.argmin(data, axis, keepdims)
+class TopK(OnnxOpConverter):
+ """Converts an onnx TopK node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ k = inputs[1]
+ if not isinstance(k, relax.Constant):
+ raise ValueError("TopK k must be a constant")
+ k = int(k.data.numpy())
+ axis = attr.get("axis", -1)
+ largest = attr.get("largest", 1)
+ sorted = attr.get("sorted", 1)
+ if sorted != 1:
+ raise ValueError("TopK sorted must be 1 for Relax frontend")
+
+ return relax.op.topk(data, k, axis, ret_type="both", largest=largest)
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ k = attr.get("k", 1)
+ axis = attr.get("axis", -1)
+ return relax.op.topk(data, k, axis, ret_type="both")
+
+
class SkipLayerNormalization(OnnxOpConverter):
"""Converts a microsoft contrib SkipLayerNormalization node into a Relax expression."""
@@ -1871,26 +2362,6 @@ class EmbedLayerNormalization(OnnxOpConverter):
return relax.Tuple([ln, mask_index])
-class Greater(OnnxOpConverter):
- """Converts an onnx Greater node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = _np.greater(inputs[0].data.numpy(), inputs[1].data.numpy())
- return relax.const(output, output.dtype)
- return relax.op.greater(inputs[0], inputs[1])
-
-
-class Reciprocal(OnnxOpConverter):
- """Converts an onnx Reciprocal node into an equivalent Relax expression."""
-
- @classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- input_dtype = inputs[0].struct_info.dtype
- return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0])
-
-
class OneHot(OnnxOpConverter):
"""Converts an onnx OneHot node into an equivalent Relax expression."""
@@ -1909,15 +2380,16 @@ class OneHot(OnnxOpConverter):
return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype)
-class Elu(OnnxOpConverter):
- """Converts an onnx Elu node into an equivalent Relax expression."""
+class Unique(OnnxOpConverter):
+ """Converts an onnx Unique node into an equivalent Relax expression."""
@classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
- alpha = float(attr.get("alpha", 1.0))
- return relax.expr.const(-alpha) * relax.op.nn.relu(
- relax.expr.const(1.0) - relax.op.exp(inputs[0])
- ) + relax.op.nn.relu(inputs[0])
+ def _impl_v11(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis = attr.get("axis", None)
+ sorted = bool(attr.get("sorted", 1))
+ # TODO(tvm-team): Add support for return_index, return_inverse, return_counts
+ return relax.op.unique(data, sorted=sorted, axis=axis)
class HardSigmoid(OnnxOpConverter):
@@ -1966,53 +2438,308 @@ class Not(OnnxOpConverter):
return relax.op.logical_not(inputs[0])
+class DepthToSpace(OnnxOpConverter):
+ """Converts an onnx DepthToSpace node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ block_size = int(attr["blocksize"])
+ mode = attr.get("mode", b"DCR").decode("utf-8")
+ b, c, h, w = inputs[0].struct_info.shape
+ if mode == "DCR":
+ x = relax.op.reshape(
+ inputs[0], (b, block_size, block_size, c // (block_size**2), h, w)
+ )
+ x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2])
+ return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size))
+ elif mode == "CRD":
+ x = relax.op.reshape(
+ inputs[0], (b, c // (block_size**2), block_size, block_size, h, w)
+ )
+ x = relax.op.permute_dims(x, [0, 1, 4, 2, 5, 3])
+ return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size))
+ else:
+ raise ValueError(f"Unsupported mode: {mode}, expected DCR or CRD")
+
+
+class SpaceToDepth(OnnxOpConverter):
+ """Converts an onnx SpaceToDepth node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ block_size = int(attr["blocksize"])
+ b, c, h, w = inputs[0].struct_info.shape
+ x = relax.op.reshape(
+ inputs[0], (b, c, h // block_size, block_size, w // block_size, block_size)
+ )
+ x = relax.op.permute_dims(x, [0, 3, 5, 1, 2, 4])
+ return relax.op.reshape(
+ x, (b, c * block_size * block_size, h // block_size, w // block_size)
+ )
+
+
+class SequenceConstruct(OnnxOpConverter):
+ """Operator converter for sequence construction op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ # Construct a tuple from input tensors.
+ return relax.Tuple(inputs)
+
+
+class SequenceEmpty(OnnxOpConverter):
+ """Operator converter for sequence empty op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ # Construct an empty tuple.
+ return relax.Tuple([])
+
+
+class SequenceErase(OnnxOpConverter):
+ """Operator converter for sequence erase op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ # Erase tensor from sequence on specified position
+ input_sequence = inputs[0]
+
+ if len(inputs) == 2:
+ position = inputs[1]
+ # Non constant position is not supported.
+ if isinstance(position, relax.Constant):
+ position = int(position.data.numpy())
+ else:
+ raise NotImplementedError("Position must be a constant.")
+ else:
+ position = -1
+
+ seq_len = len(input_sequence)
+ if not -seq_len <= position < seq_len:
+ raise ValueError(
+ f"Position is out of bounds, expected [-{seq_len}, {seq_len}), got {position}"
+ )
+
+ if position < 0:
+ position = seq_len + position
+ # Convert sequence to a list, drop the tensor at the erased position, and repackage as Tuple.
+ tensor_list = [input_sequence[i] for i in range(seq_len) if i != position]
+ # Create new tuple and return.
+ return relax.Tuple(tensor_list)
+
+
+class SequenceInsert(OnnxOpConverter):
+ """Operator converter for sequence insert op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ # Insert a new tensor into a tuple of tensors.
+ input_sequence = inputs[0]
+ tensor_to_insert = inputs[1]
+
+ if len(inputs) == 3:
+ position = inputs[2]
+ # Non constant position is not supported.
+ if isinstance(position, relax.Constant):
+ position = position.data.numpy()
+ else:
+ raise NotImplementedError("Position must be a constant.")
+ else:
+ position = -1
+
+ if position < 0:
+ position = len(input_sequence) + position + 1
+ # Convert sequence to a list, insert new tensor, and repackage as Tuple.
+ tensor_list = [input_sequence[i] for i in range(len(input_sequence))]
+ # Insert new tensor.
+ tensor_list.insert(position, tensor_to_insert)
+ # Create new tuple and return.
+ return relax.Tuple(tensor_list)
+
+
+class SequenceLength(OnnxOpConverter):
+ """Operator converter for sequence length op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ # Get length of input sequence
+ return relax.const(len(inputs[0]), dtype="int64")
+
+
+class ConcatFromSequence(OnnxOpConverter):
+ """Operator converter for sequence concatenation op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 0)
+ new_axis = attr.get("new_axis", 0)
+
+ if new_axis == 1:
+ raise NotImplementedError("Insert new axis is not supported yet.")
+
+ return relax.op.concat(inputs[0], axis=axis)
+
+
+class SplitToSequence(OnnxOpConverter):
+ """Operator converter for split to sequence op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 0)
+ keepdims = attr.get("keepdims", 1)
+
+ input_tensor = inputs[0]
+ input_shape = input_tensor.struct_info.shape
+
+ # If split is not provided, we split all values along axis.
+ if len(inputs) == 1:
+ split = _np.array(1)
+ if not keepdims:
+ raise NotImplementedError("Only keepdims=1 is supported for now")
+ else:
+ split = inputs[1]
+ if not isinstance(split, relax.Constant):
+ raise ValueError("Only constant split supported for SplitToSequence")
+ split = split.data.numpy()
+
+ if len(split.shape) == 1 and split.shape[0] > 1:
+ split = _np.cumsum(split)
+ split = list(split[:-1])
+ else:
+ chunk_size, dim_size = int(split), input_shape[axis]
+ if dim_size % chunk_size != 0:
+ raise ValueError(
+ f"Dimension of size {dim_size} along axis {axis} must be "
+ f"evenly divisible by chunk size {chunk_size}"
+ )
+ split = dim_size // chunk_size
+
+ output = relax.op.split(input_tensor, split, axis=axis)
+ return output
+
+
+class SequenceAt(OnnxOpConverter):
+ """Operator converter for sequence at op."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ input_sequence = inputs[0]
+ position = inputs[1]
+ assert isinstance(
+ position, relax.Constant
+ ), "Only constant position supported for SequenceAt"
+ position = int(position.data.numpy())
+ return input_sequence[position]
+
+
def _get_convert_map():
return {
- "MatMul": MatMul,
- "Concat": Concat,
+ # defs/experimental
+ # "Optional": Optional_,
+ # "OptionalHasElement": OptionalHasElement,
+ # "OptionalGetElement": OptionalGetElement,
+ # Binary operators
"Add": Add,
+ "Sub": Sub,
"Mul": Mul,
- "Cast": Cast,
+ "Div": Div,
+ # "Mod": Mod,
+ "Less": Less,
+ "LessOrEqual": LessOrEqual,
+ "Greater": Greater,
+ "GreaterOrEqual": GreaterOrEqual,
+ "Equal": Equal,
+ "BitwiseAnd": BitwiseAnd,
+ "BitwiseOr": BitwiseOr,
+ "BitwiseXor": BitwiseXor,
+ # "BitwiseNot": BitwiseNot,
+ # "BitwiseShift": BitwiseShift,
+ "And": And,
+ "Or": Or,
+ "Xor": Xor,
+ "Not": Not,
+ # Unary operators
+ "Log": Log,
+ "Exp": Exp,
+ "Acos": Acos,
+ "Acosh": Acosh,
+ "Asin": Asin,
+ "Asinh": Asinh,
+ "Atan": Atan,
+ "Atanh": Atanh,
+ "Cos": Cos,
+ "Cosh": Cosh,
+ "Sin": Sin,
+ "Sinh": Sinh,
+ "Tan": Tan,
+ "Tanh": Tanh,
+ "Neg": Neg,
+ "Abs": Abs,
+ "Reciprocal": Reciprocal,
+ "Floor": Floor,
+ "Ceil": Ceil,
+ "Round": Round,
+ "IsInf": IsInf,
+ "IsNaN": IsNaN,
+ "Sqrt": Sqrt,
+ "Relu": Relu,
+ "Selu": Selu,
+ "Mish": Mish,
+ "Trilu": Trilu,
+ "PRelu": PRelu,
+ "LeakyRelu": LeakyRelu,
+ "ThresholdedRelu": ThresholdedRelu,
+ "Elu": Elu,
+ "Gelu": Gelu,
+ "FastGelu": FastGelu,
+ "BiasGelu": BiasGelu,
+ "HardSigmoid": HardSigmoid,
+ "HardSwish": HardSwish,
+ "Sign": Sign,
+ "Softplus": Softplus,
+ "Softsign": Softsign,
+ "Shrink": Shrink,
+ "Erf": Erf,
"Sum": Sum,
- "Gather": Gather,
+ "Min": Min,
+ "Max": Max,
+ "Mean": Mean,
+ "Cast": Cast,
"Gemm": Gemm,
+ "MatMul": MatMul,
+ # "MatMulInteger": MatMulInteger,
+ # "MatMulInteger16": MatMulInteger16,
"Reshape": Reshape,
- "Div": Div,
"Sigmoid": Sigmoid,
"Softmax": Softmax,
+ "LogSoftmax": LogSoftmax,
+ # "Hardmax": Hardmax,
"Transpose": Transpose,
"Unsqueeze": Unsqueeze,
- "Gelu": Gelu,
- "BiasGelu": BiasGelu,
"Where": Where,
+ "Concat": Concat,
"Clip": Clip,
- "Equal": Equal,
"Shape": Shape,
- "Tanh": Tanh,
- "Sqrt": Sqrt,
- "Trilu": Trilu,
- "Relu": Relu,
- "Conv": Conv,
"Pow": Pow,
- "Erf": Erf,
"CumSum": CumSum,
"Squeeze": Squeeze,
"Constant": Constant,
- "Sub": Sub,
- "Sin": Sin,
- "Cos": Cos,
- "Neg": Neg,
- "Abs": Abs,
- "Min": Min,
- "Max": Max,
- "Log": Log,
- "Exp": Exp,
- "Less": Less,
- "LessOrEqual": LessOrEqual,
+ "Gather": Gather,
+ # "GatherElements": GatherElements,
+ # "GatherND": GatherND,
+ "Scatter": Scatter,
+ "ScatterElements": ScatterElements,
+ # "ScatterND": ScatterND,
+ # "Compress": Compress,
+ "Size": Size,
+ # "EyeLike": EyeLike,
+ # Normalization
+ "BatchNormalization": BatchNormalization,
"LayerNormalization": LayerNormalization,
"SkipLayerNormalization": SkipLayerNormalization,
"EmbedLayerNormalization": EmbedLayerNormalization,
"InstanceNormalization": InstanceNormalization,
+ "MeanVarianceNormalization": MeanVarianceNormalization,
# defs/reduction
"ReduceMax": ReduceMax,
"ReduceMin": ReduceMin,
@@ -2026,6 +2753,7 @@ def _get_convert_map():
"ReduceL2": ReduceL2,
"ArgMax": ArgMax,
"ArgMin": ArgMin,
+ "TopK": TopK,
"Expand": Expand,
"ConstantOfShape": ConstantOfShape,
"Slice": Slice,
@@ -2033,23 +2761,42 @@ def _get_convert_map():
"Pad": Pad,
"Split": Split,
"Tile": Tile,
- "BatchNormalization": BatchNormalization,
- "MaxPool": MaxPool,
"AveragePool": AveragePool,
+ "MaxPool": MaxPool,
+ # "LpPool": LpPool,
"GlobalAveragePool": GlobalAveragePool,
+ "GlobalMaxPool": GlobalMaxPool,
+ "GlobalLpPool": GlobalLpPool,
+ "MaxUnpool": MaxUnpool,
+ "Conv": Conv,
+ "ConvTranspose": ConvTranspose,
"Flatten": Flatten,
"Identity": Identity,
"Resize": Resize,
"Einsum": Einsum,
"Range": Range,
- "Greater": Greater,
- "Reciprocal": Reciprocal,
"OneHot": OneHot,
- "Elu": Elu,
- "HardSigmoid": HardSigmoid,
- "HardSwish": HardSwish,
- "Sign": Sign,
- "Not": Not,
+ "Unique": Unique,
+ # "NonZero": NonZero,
+ # "If": If,
+ # "LRN": LRN,
+ # "MaxRoiPool": MaxRoiPool,
+ # "RoiAlign": RoiAlign,
+ # "NonMaxSuppression": NonMaxSuppression,
+ # "GridSample": GridSample,
+ # "Upsample": Upsample,
+ # others
+ "DepthToSpace": DepthToSpace,
+ "SpaceToDepth": SpaceToDepth,
+ # Sequence operators
+ "SequenceConstruct": SequenceConstruct,
+ "SequenceEmpty": SequenceEmpty,
+ "SequenceErase": SequenceErase,
+ "SequenceInsert": SequenceInsert,
+ "SequenceLength": SequenceLength,
+ "ConcatFromSequence": ConcatFromSequence,
+ "SplitToSequence": SplitToSequence,
+ "SequenceAt": SequenceAt,
}
@@ -2269,6 +3016,14 @@ class ONNXGraphImporter:
"Where",
"Cast",
]
+ return_tuple_ops = [
+ "SequenceConstruct",
+ "SequenceEmpty",
+ "SequenceErase",
+ "SequenceInsert",
+ "ConcatFromSequence",
+ "SplitToSequence",
+ ]
for i, inp in enumerate(inputs):
if (
inp is not None
@@ -2277,11 +3032,17 @@ class ONNXGraphImporter:
and op_name not in shape_compatible_ops
):
raise ValueError(f"Node {node.name} cannot handle ShapeExpr inputs.")
- op = self._convert_operator(op_name, inputs, attr, self.opset)
- # Create struct information for the new operator.
- op = self.bb.normalize(op)
-
- if not isinstance(op, relax.Tuple):
+ try:
+ op = self._convert_operator(op_name, inputs, attr, self.opset)
+ # Create struct information for the new operator.
+ op = self.bb.normalize(op)
+ except TVMError as err:
print(f"Error converting operator {op_name}, with inputs: {inputs}")
+ raise err
+
+ if op_name in return_tuple_ops:
+ outputs_num = 1
+ elif not isinstance(op, relax.Tuple):
if isinstance(op.checked_type, tvm.ir.type.TupleType):
# This is a var bound to a tuple. We need to unpack it and create
# a new tuple.
@@ -2299,7 +3060,6 @@ class ONNXGraphImporter:
), "Missing outputs during conversion. Expected {} but Got {} in {}.".format(
len(outputs), outputs_num, op_name
)
-
if outputs_num == 1:
self._nodes[outputs[0]] = op
else:
@@ -2346,10 +3106,10 @@ class ONNXGraphImporter:
def _convert_operator(
self,
op_name: str,
- inputs: List[relax.Function],
+ inputs: List[relax.Expr],
attrs: Dict,
opset: int,
- ) -> relax.Function:
+ ) -> relax.Expr:
"""Convert ONNX operator into a Relax operator.
The converter must specify conversions explicitly for incompatible
name, and
apply handlers to operator attributes.
@@ -2386,7 +3146,7 @@ def from_onnx(
opset: int = None,
keep_params_in_input: bool = False,
sanitize_input_names: bool = True,
-) -> Tuple[IRModule, Dict]:
+) -> IRModule:
"""Convert a ONNX model into an equivalent Relax Function.
ONNX graphs are represented as Python Protobuf objects.
@@ -2413,8 +3173,6 @@ def from_onnx(
-------
mod : tvm.IRModule
The relax module for compilation
- params : dict of str to tvm.nd.NDArray
- The parameter dict to be used by relax
"""
# Error if the model version is below 1.1.0
if model.ir_version < 3:
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index 4d106ad6d2..0b86e19ce5 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -77,7 +77,7 @@ def unique(
return_inverse = PrimValue(return_inverse)
if isinstance(return_counts, bool):
return_counts = PrimValue(return_counts)
- if axis and isinstance(axis, int):
+ if axis is not None and isinstance(axis, int):
axis = PrimValue(axis)
return _ffi_api.unique( # type: ignore
x, sorted, return_index, return_inverse, return_counts, axis
@@ -91,6 +91,7 @@ def numpy_unique(
return_index: int,
return_inverse: int,
return_counts: int,
+ axis: Optional[int] = None,
) -> tvm.nd.array:
"""Returns the unique elements of the input tensor.
@@ -103,8 +104,9 @@ def numpy_unique(
raise NotImplementedError("missing support return_inverse or return_counts set to true")
x_numpy = x.numpy()
# TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci.
- output_sorted_numpy, indices = np.unique(x_numpy, return_index=True)
+ output_sorted_numpy, indices = np.unique(x_numpy, return_index=True, axis=axis)
+
if sorted:
return tvm.nd.array(output_sorted_numpy)
- output_numpy = [x_numpy.flatten()[index] for index in builtins.sorted(indices, reverse=True)]
+ output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis)
return tvm.nd.array(output_numpy)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 809d231fd3..8317d4504e 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -171,21 +171,16 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr:
"and thus cannot be legalized by TOPI"
)
return call
- if call.attrs.groups != 1:
- logging.info(
- "TOPI conv1d_transpose does not support groups other than 1, "
- "and thus cannot be legalized by TOPI"
- )
- return call
return bb.call_te(
- topi.nn.conv1d_transpose_ncw,
+ topi.nn.group_conv1d_transpose_ncw,
call.args[0],
call.args[1],
stride=call.attrs.strides,
padding=call.attrs.padding,
out_dtype=call.struct_info.dtype,
output_padding=call.attrs.output_padding,
+ groups=call.attrs.groups,
primfunc_name_hint="conv1d_transpose",
)
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 0e7cfbd7c0..2837ad2185 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -21,7 +21,7 @@ ONNX testcases
This file is a test script to test Relax ONNX frontend coverage.
"""
-from typing import Dict, Optional
+from typing import Dict, List, Literal, Optional
import numpy as np
import onnx
@@ -118,6 +118,7 @@ def check_correctness(
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
# Legalize any relax ops into tensorir.
tvm_model = relax.transform.LegalizeOps()(tvm_model)
+ print(tvm_model)
# Separate model from parameters.
tvm_model, params = relax.frontend.detach_params(tvm_model)
@@ -137,25 +138,31 @@ def check_correctness(
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
# Wrap as a list if there is only one output.
- if isinstance(tvm_output, tvm.nd.NDArray):
+ if len(ort_output) == 1:
+ # Do not check the output number for TVM
+ # As for sequence output, the TVM output is a Tuple
+ # while the ONNX output number is one, which is a list
tvm_output = [tvm_output]
- # If the output is a shape tuple, convert it to an ndarray for comparison.
- if isinstance(tvm_output, tvm.runtime.ShapeTuple):
- tvm_output = [tvm.nd.array([int(i) for i in tvm_output])]
- tvm_num_outputs = len(tvm_output)
- # Shape tuples need to be handled specially.
- if isinstance(tvm_output, tvm.runtime.ShapeTuple):
- tvm_num_outputs = 1
+ def _check_output(tvm_out, ort_out):
+ if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
+ assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
+ for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
+ _check_output(tvm_out_i, ort_out_i)
+ elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
+ tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
+ elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
+ shape_out = tvm.nd.array([int(i) for i in tvm_out])
+ tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
+ else:
+ raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
# Check that number of outputs match.
- assert tvm_num_outputs == len(ort_output), "Unequal number of outputs"
-
+ assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
for (tvm_out, ort_out) in zip(tvm_output, ort_output):
# TODO Allow configurable tolerance.
- # Sometimes None is used to indicate an unused output.
if ort_out is not None:
- tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
+ _check_output(tvm_out, ort_out)
@pytest.mark.parametrize(
@@ -187,35 +194,61 @@ def test_sanitize(input_names, expected_names):
assert param.name_hint == expected_names[i]
-def verify_unary(op_name, shape, attrs={}, domain=None, dtype=TensorProto.FLOAT):
+def verify_unary(
+ op_name,
+ shape,
+ attrs={},
+ domain=None,
+ input_dtype=TensorProto.FLOAT,
+ output_dtype=TensorProto.FLOAT,
+ opset=14,
+):
test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
graph = helper.make_graph(
[test_node],
"elemwise_test",
inputs=[
- helper.make_tensor_value_info("x", dtype, shape),
+ helper.make_tensor_value_info("x", input_dtype, shape),
],
- outputs=[helper.make_tensor_value_info("y", dtype, shape)],
+ outputs=[helper.make_tensor_value_info("y", output_dtype, shape)],
)
model = helper.make_model(graph, producer_name="elemwise_test")
- check_correctness(model)
+ check_correctness(model, opset=opset)
-def verify_binary(op_name, shape_a, shape_b, shape_c, attrs={}, domain=None):
+def verify_binary(
+ op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, dtype=TensorProto.FLOAT, opset=14
+):
test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain)
graph = helper.make_graph(
[test_node],
"binary_test",
inputs=[
- helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a),
- helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b),
+ helper.make_tensor_value_info("a", dtype, shape_a),
+ helper.make_tensor_value_info("b", dtype, shape_b),
],
- outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c)],
+ outputs=[helper.make_tensor_value_info("c", dtype, shape_c)],
)
model = helper.make_model(graph, producer_name="binary_test")
- check_correctness(model)
+ check_correctness(model, opset=opset)
+
+
+def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14):
+ a = make_constant_node("a", dtype, [], [4])
+ b = make_constant_node("b", dtype, [], [8])
+ test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain)
+ graph = helper.make_graph(
+ [a, b, test_node],
+ "binary_test",
+ inputs=[],
+ outputs=[helper.make_tensor_value_info("c", dtype, ())],
+ )
+
+ model = helper.make_model(graph, producer_name="binary_test")
+ # NOTE: explicitly pass inputs to avoid numerical error
+ check_correctness(model, opset=opset)
def verify_compare(op_name, shape, attrs={}, domain=None):
@@ -289,16 +322,95 @@ def test_concat():
verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0})
-def test_add():
- verify_binary("Add", [1, 32], [1, 32], [1, 32])
+@pytest.mark.parametrize("op_name", ["Add", "Sub", "Mul", "Div", "Pow"])
+def test_binary(op_name: str):
+ verify_binary(op_name, [1, 32], [1, 32], [1, 32])
+ verify_binary_scalar(op_name)
+
+
+@pytest.mark.parametrize("num_inputs", [1, 2, 4])
+@pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
+def test_multi_input(op_name: str, num_inputs: int):
+ input_shape = [32, 32]
+ input_var = ["i" + str(i) for i in range(num_inputs)]
+ input_values = [
+ helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for
var in input_var
+ ]
+ test_node = helper.make_node(op_name, input_var, ["c"])
+ graph = helper.make_graph(
+ [test_node],
+ "multi_input_test",
+ inputs=input_values,
+ outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT,
input_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="multi_input_test")
+ check_correctness(model)
-def test_mul():
- verify_binary("Mul", [1, 32], [1, 32], [1, 32])
[email protected]("op_name", ["Less", "LessOrEqual", "Greater",
"GreaterOrEqual"])
+def test_compare(op_name: str):
+ verify_compare(op_name, [1, 32])
-def test_sum():
- verify_binary("Sum", [1, 32], [1, 32], [1, 32])
[email protected]("op_name", ["And", "Or", "Xor"])
+def test_binary_bool(op_name: str):
+ verify_binary(op_name, [32, 32], [32, 32], [32, 32],
dtype=TensorProto.BOOL)
+
+
+@pytest.mark.parametrize(
+ "op_name",
+ [
+ "Sin",
+ "Cos",
+ "Tan",
+ "Sinh",
+ "Cosh",
+ "Tanh",
+ "Asin",
+ "Acos",
+ "Atan",
+ "Asinh",
+ "Acosh",
+ "Atanh",
+ "Neg",
+ "Abs",
+ "Log",
+ "Exp",
+ "Not",
+ "Reciprocal",
+ "Floor",
+ "Ceil",
+ "Round",
+ "IsInf",
+ "IsNaN",
+ "Sqrt",
+ "Relu",
+ "Elu",
+ "HardSwish",
+ "Sign",
+ "Softplus",
+ "Softsign",
+ "Erf",
+ "Sigmoid",
+ "Softmax",
+ "LogSoftmax",
+ "Identity",
+ ],
+)
+def test_unary(op_name: str):
+ input_dtype = TensorProto.FLOAT
+ if op_name in [
+ "IsNaN",
+ "IsInf",
+ ]:
+ pytest.skip(f"Skipping test {op_name} because current LegalizeOps does
not support it.")
+ elif op_name == "Not":
+ input_dtype = TensorProto.BOOL
+ output_dtype = TensorProto.BOOL
+ else:
+ output_dtype = TensorProto.FLOAT
+ verify_unary(op_name, [32, 32], input_dtype=input_dtype,
output_dtype=output_dtype)
@pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT,
TensorProto.FLOAT16])
@@ -350,6 +462,44 @@ def test_gather():
_verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
[email protected]("axis", [0, 1, 2])
[email protected](("name", "opset"), [("Scatter", 10),
("ScatterElements", 11)])
+def test_scatter(axis: int, name: str, opset: int):
+ if axis != 1:
+ pytest.skip("The current topi impl is wrong, which only works for
axis=1")
+ input_shape = [16, 16, 16]
+ indices_shape = [8, 8, 8]
+ updates_shape = [8, 8, 8]
+ output_shape = [16, 16, 16]
+ node = helper.make_node(name, ["data", "indices", "updates"], ["output"],
axis=axis)
+ graph = helper.make_graph(
+ [node],
+ "scatter_test",
+ inputs=[
+ helper.make_tensor_value_info("data", TensorProto.FLOAT,
input_shape),
+ helper.make_tensor_value_info("indices", TensorProto.INT64,
indices_shape),
+ helper.make_tensor_value_info("updates", TensorProto.FLOAT,
updates_shape),
+ ],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
output_shape)],
+ )
+ model = helper.make_model(graph, producer_name="scatter_test")
+ indices = np.random.randint(0, 16, indices_shape)
+ check_correctness(model, inputs={"indices": indices}, opset=opset)
+
+
+def test_size():
+ test_node = helper.make_node("Size", ["x"], ["y"])
+ graph = helper.make_graph(
+ [test_node],
+ "size_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3,
3])],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.INT64, [3])],
+ )
+
+ model = helper.make_model(graph, producer_name="size_test")
+ check_correctness(model)
+
+
@pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
@pytest.mark.parametrize("beta", [None, 0.35, 1.0])
@pytest.mark.parametrize("useC", [False, True])
@@ -408,18 +558,6 @@ def test_reshape(in_shape, shape, out_shape):
check_correctness(model, inputs=input_values)
-def test_div():
- verify_binary("Div", [32, 32], [32, 32], [32, 32])
-
-
-def test_sigmoid():
- verify_unary("Sigmoid", [32, 32])
-
-
-def test_softmax():
- verify_unary("Softmax", [32, 32, 32])
-
-
def test_transpose():
verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]})
@@ -567,28 +705,33 @@ def test_shape():
check_correctness(model)
-def test_tanh():
- verify_unary("Tanh", [9, 8, 7, 6])
[email protected]("upper", [True, False])
+def test_trilu(upper: bool):
+ verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper})
-def test_sqrt():
- verify_unary("Sqrt", [32, 32])
+def test_selu():
+ verify_unary("Selu", [3, 32, 32])
+ verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3})
-def test_relu():
- verify_unary("Relu", [32, 32])
+@pytest.mark.skip(reason="opset 18 is not supported in CI")
+def test_mish():
+ verify_unary("Mish", [3, 32, 32], opset=18)
-def test_tril():
- verify_unary("Trilu", [3, 5, 5], attrs={"upper": False})
+def test_prelu():
+ verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32])
-def test_triu():
- verify_unary("Trilu", [3, 5, 5], attrs={"upper": True})
+def test_thresholded_relu():
+ verify_unary("ThresholdedRelu", [3, 32, 32])
+ verify_unary("ThresholdedRelu", [3, 32, 32], attrs={"alpha": -0.01})
-def test_elu():
- verify_unary("Elu", [32, 32])
+def test_leakyrelu():
+ verify_unary("LeakyRelu", [32, 32])
+ verify_unary("LeakyRelu", [32, 32], attrs={"alpha": 0.2})
def test_hardsigmoid():
@@ -597,30 +740,40 @@ def test_hardsigmoid():
verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta":
0.6})
-def test_hardswish():
- verify_unary("HardSwish", [32, 32])
-
-
-def test_sign():
- verify_unary("Sign", [32, 32])
-
-
-def test_not():
- verify_unary("Not", [32, 32], dtype=TensorProto.BOOL)
+def test_shrink():
+ verify_unary("Shrink", [32, 32])
+ verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})
-def test_conv():
- def _verify_conv(input_shape, weight_shape, output_shape):
[email protected]("stride", [1, 2])
[email protected]("dilation", [1, 2])
[email protected]("bias", [True, False])
[email protected]("pad", [0, 2])
+def test_conv(stride: int, dilation: int, pad: int, bias: bool):
+ def _verify_conv(input_shape, weight_shape):
+ nd = len(weight_shape) - 2
+ output_shape = [input_shape[0], weight_shape[0]] + [
+ (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1)
// stride + 1
+ for i in range(2, len(input_shape))
+ ]
bias_shape = [output_shape[1]]
- conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"])
+ conv_node = helper.make_node(
+ "Conv",
+ inputs=["x", "w"] + (["b"] if bias else []),
+ outputs=["y"],
+ strides=[stride] * nd,
+ dilations=[dilation] * nd,
+ pads=[pad] * nd * 2,
+ group=input_shape[1] // weight_shape[1],
+ )
graph = helper.make_graph(
[conv_node],
"conv_test",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT,
input_shape),
helper.make_tensor_value_info("w", TensorProto.FLOAT,
weight_shape),
- helper.make_tensor_value_info("b", TensorProto.FLOAT,
bias_shape),
- ],
+ ]
+ + ([helper.make_tensor_value_info("b", TensorProto.FLOAT,
bias_shape)] if bias else []),
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT,
output_shape)],
)
@@ -628,20 +781,61 @@ def test_conv():
check_correctness(model, atol=1e-4)
# Conv1D
- _verify_conv([3, 12, 32], [4, 12, 3], [3, 4, 30])
+ _verify_conv([3, 4, 32], [4, 4, 3])
+ _verify_conv([3, 4, 32], [2, 4, 3]) # group=2
# Conv2D
- _verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30])
+ _verify_conv([3, 4, 32, 32], [4, 4, 3, 3])
+ _verify_conv([3, 4, 32, 32], [2, 4, 3, 3]) # group=2
# Conv3D
- _verify_conv([3, 12, 32, 32, 32], [4, 12, 3, 3, 3], [3, 4, 30, 30, 30])
+ _verify_conv([3, 4, 32, 32, 32], [4, 4, 3, 3, 3])
+ _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3]) # group=2
+
+
[email protected]("stride", [1, 2])
[email protected]("dilation", [1])
[email protected]("bias", [True, False])
[email protected]("pad", [0, 2])
+def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool):
+ def _verify_conv_transpose(input_shape, weight_shape):
+ nd = len(weight_shape) - 2
+ output_shape = [input_shape[0], weight_shape[0]] + [
+ (input_shape[i] - 1) * stride - 2 * pad + dilation *
(weight_shape[i] - 1) + 1
+ for i in range(2, len(input_shape))
+ ]
+ bias_shape = [output_shape[1]]
+ conv_node = helper.make_node(
+ "ConvTranspose",
+ inputs=["x", "w"] + (["b"] if bias else []),
+ outputs=["y"],
+ strides=[stride] * nd,
+ dilations=[dilation] * nd,
+ pads=[pad] * nd * 2,
+ group=input_shape[1] // weight_shape[1],
+ )
+ graph = helper.make_graph(
+ [conv_node],
+ "conv_transpose_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT,
input_shape),
+ helper.make_tensor_value_info("w", TensorProto.FLOAT,
weight_shape),
+ ]
+ + ([helper.make_tensor_value_info("b", TensorProto.FLOAT,
bias_shape)] if bias else []),
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT,
output_shape)],
+ )
+ model = helper.make_model(graph, producer_name="conv_transpose_test")
+ check_correctness(model, atol=1e-4)
-def test_pow():
- verify_binary("Pow", [32, 32], [32, 32], [32, 32])
+ # ConvTranspose1D
+ _verify_conv_transpose([3, 4, 32], [4, 4, 3])
+ _verify_conv_transpose([3, 4, 32], [4, 2, 3]) # group=2
+ # ConvTranspose2D
+ _verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3])
+ _verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3]) # group=2
-def test_erf():
- verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT)
- verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT16)
+def test_pow():
+ verify_binary("Pow", [32, 32], [32, 32], [32, 32])
@pytest.mark.parametrize("reverse", [False])
@@ -712,46 +906,6 @@ def test_const():
check_correctness(model)
-def test_sub():
- verify_binary("Sub", [32, 16], [32, 16], [32, 16])
-
-
-def test_min():
- verify_binary("Min", [32, 16], [32, 16], [32, 16])
-
-
-def test_max():
- verify_binary("Max", [32, 16], [32, 16], [32, 16])
-
-
-def test_sin():
- verify_unary("Sin", [32, 16])
-
-
-def test_cos():
- verify_unary("Cos", [32, 16])
-
-
-def test_identity():
- verify_unary("Identity", [32, 16])
-
-
-def test_neg():
- verify_unary("Neg", [32, 16])
-
-
-def test_abs():
- verify_unary("Abs", [32, 16])
-
-
-def test_log():
- verify_unary("Log", [32, 16])
-
-
-def test_exp():
- verify_unary("Exp", [32, 16])
-
-
def test_instance_norm():
verify_ternary(
"InstanceNormalization", [1, 3, 32, 32], [3], [3], [1, 3, 32, 32],
attrs={"epsilon": 1e-12}
@@ -761,6 +915,11 @@ def test_instance_norm():
)
+def test_mean_variance_norm():
+ verify_unary("MeanVarianceNormalization", [1, 3, 32, 32])
+ verify_unary("MeanVarianceNormalization", [1, 3, 32, 32], attrs={"axes":
(1, 2, 3)})
+
+
def test_layer_norm():
layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"],
["d"], epsilon=1e-12)
@@ -1075,9 +1234,36 @@ def test_arg_min_max(in_dtype, axis, keepdims):
verify_arg_min_max([3, 4, 4], in_dtype, "ArgMin", axis, keepdims)
[email protected]("axis", [-1, 0, 1])
[email protected]("largest", [True, False])
+def test_topk(axis: int, largest: int):
+ in_shape = [32, 32, 32]
+ k_value = 4
+ out_shape = in_shape
+ out_shape[axis] = k_value
+ k = make_constant_node("k", TensorProto.INT64, [1], [k_value])
+ node = onnx.helper.make_node(
+ "TopK",
+ inputs=["data", "k"],
+ outputs=["values", "indices"],
+ axis=axis,
+ largest=largest,
+ )
+ graph = helper.make_graph(
+ [k, node],
+ "topk_test",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT,
in_shape)],
+ outputs=[
+ helper.make_tensor_value_info("values", TensorProto.FLOAT,
out_shape),
+ helper.make_tensor_value_info("indices", TensorProto.INT64,
out_shape),
+ ],
+ )
+ model = helper.make_model(graph, producer_name="topk_test")
+
+ check_correctness(model)
+
+
@pytest.mark.parametrize("dynamic", [False, True])
-# TODO(jwfromm) Current approach to dynamic expand is technically not well
formed. Reenable once fixed.
[email protected]("Produces ill-formed IR")
def test_expand(dynamic):
if dynamic:
# TODO: Support dynamic shape for Expand
@@ -1586,14 +1772,6 @@ def test_range():
check_correctness(model)
-def test_less():
- verify_compare("Less", [32, 32])
-
-
-def test_less_equal():
- verify_compare("LessOrEqual", [32, 32])
-
-
def test_batch_norm():
batch_norm_node = helper.make_node(
"BatchNormalization", ["x", "s", "bias", "mean", "var"], ["y"],
epsilon=1e-2
@@ -1811,17 +1989,58 @@ def test_global_average_pool():
verify_unary("GlobalAveragePool", [1, 3, 32, 32, 32])
+def test_global_max_pool():
+ verify_unary("GlobalMaxPool", [1, 3, 32])
+ verify_unary("GlobalMaxPool", [1, 3, 32, 32])
+ verify_unary("GlobalMaxPool", [1, 3, 32, 32, 32])
+
+
[email protected]("p", [1, 2, 3])
+def test_global_lp_pool(p: int):
+ verify_unary("GlobalLpPool", [1, 3, 32], attrs={"p": p})
+ verify_unary("GlobalLpPool", [1, 3, 32, 32], attrs={"p": p})
+ verify_unary("GlobalLpPool", [1, 3, 32, 32, 32], attrs={"p": p})
+
+
[email protected]("kernel_shape", [[2, 2], [3, 3]])
[email protected]("pads", [None, [1, 1, 1, 1]])
[email protected]("strides", [None, [2, 2]])
+def test_maxunpool(kernel_shape, pads, strides):
+ input_shape = [16, 3, 16, 16]
+ input_names = ["X", "I"]
+ input_info = [
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape),
+ helper.make_tensor_value_info("I", TensorProto.INT64, input_shape),
+ ]
+
+ attrs = {"kernel_shape": kernel_shape}
+ if pads is not None:
+ attrs["pads"] = pads
+ if strides is not None:
+ attrs["strides"] = strides
+
+ node = helper.make_node("MaxUnpool", inputs=input_names, outputs=["y"],
**attrs)
+
+ graph = helper.make_graph(
+ [node],
+ "maxunpool_test",
+ inputs=input_info,
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
+ )
+
+ max_random = int(np.prod(np.array(kernel_shape)))
+ indices = np.random.randint(0, max_random, size=input_shape)
+
+ model = helper.make_model(graph, producer_name="maxunpool_test")
+ check_correctness(model, inputs={"I": indices})
+
+
def test_flatten():
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0})
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1})
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2})
-def test_greater():
- verify_compare("Greater", [32, 32])
- verify_compare("Greater", [64, 16])
-
-
def test_onehot():
one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"],
["y"], axis=1)
graph = helper.make_graph(
@@ -1844,8 +2063,189 @@ def test_onehot():
check_correctness(model, inputs=values)
-def test_reciprocal():
- verify_unary("Reciprocal", [3, 32, 32])
[email protected]("axis", [None, 0, 1, -1])
[email protected]("sorted", [0, 1])
+def test_unique(axis: Optional[int], sorted: int):
+ input_shape = [32, 32]
+ if axis is None:
+ output_shape = [-1]
+ else:
+ output_shape = [32, 32]
+ output_shape[axis] = -1
+ unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis,
sorted=sorted)
+ graph = helper.make_graph(
+ [unique_node],
+ "unique_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
input_shape)],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT,
output_shape)],
+ )
+ model = helper.make_model(graph, producer_name="unique_test")
+ check_correctness(model)
+
+
[email protected]("mode", ["DCR", "CRD"])
+def test_depth_to_space(mode: Literal["DCR", "CRD"]):
+ in_shape = [1, 8, 2, 3]
+ out_shape = [1, 2, 4, 6]
+ blocksize = 2
+ node = onnx.helper.make_node(
+ "DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blocksize,
mode=mode
+ )
+ graph = helper.make_graph(
+ [node],
+ "depth_to_space_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
in_shape)],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT,
out_shape)],
+ )
+ model = helper.make_model(graph, producer_name="depth_to_space_test")
+
+ check_correctness(model)
+
+
+def test_space_to_depth():
+ in_shape = [1, 2, 4, 6]
+ out_shape = [1, 8, 2, 3]
+ blocksize = 2
+ node = onnx.helper.make_node("SpaceToDepth", inputs=["x"], outputs=["y"],
blocksize=blocksize)
+ graph = helper.make_graph(
+ [node],
+ "space_to_depth_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
in_shape)],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT,
out_shape)],
+ )
+ model = helper.make_model(graph, producer_name="space_to_depth_test")
+
+ check_correctness(model)
+
+
+def construct_sequence(input_shape: List[int], num_tensors: int, name: str =
"sequence"):
+ inputs = [f"data{i}" for i in range(num_tensors)]
+ sequence_construct_node = helper.make_node("SequenceConstruct", inputs,
[name])
+ graph_inputs = [
+ helper.make_tensor_value_info(f"data{i}", TensorProto.FLOAT,
input_shape)
+ for i in range(num_tensors)
+ ]
+ return sequence_construct_node, graph_inputs
+
+
+def make_constant_node(name: str, data_type: int, dims: List[int], vals:
List[int]):
+ return helper.make_node(
+ "Constant",
+ inputs=[],
+ outputs=[name],
+ value=helper.make_tensor(name=name, data_type=data_type, dims=dims,
vals=vals),
+ )
+
+
+def test_sequence_construct():
+ node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=2)
+ graph = helper.make_graph(
+ [node],
+ "test_sequence_construct",
+ inputs=graph_inputs,
+ outputs=[helper.make_tensor_sequence_value_info("sequence",
TensorProto.FLOAT, [32, 32])],
+ )
+ model = helper.make_model(graph, producer_name="test_sequence_construct")
+ check_correctness(model)
+
+
+def test_sequence_empty():
+ sequence_empty_node = helper.make_node("SequenceEmpty", [], ["sequence"])
+ graph = helper.make_graph(
+ [sequence_empty_node],
+ "test_sequence_empty",
+ inputs=[],
+ outputs=[helper.make_tensor_sequence_value_info("sequence",
TensorProto.FLOAT, [])],
+ )
+ model = helper.make_model(graph, producer_name="test_sequence_empty")
+ check_correctness(model)
+
+
[email protected]("explicit_position", [True, False])
+def test_sequence_erase(explicit_position: bool):
+ seq_node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=4)
+ index = make_constant_node("index", TensorProto.INT64, (), [1])
+ node_input = ["sequence", "index"] if explicit_position else ["sequence"]
+ sequence_erase_node = helper.make_node("SequenceErase", node_input,
["output"])
+ graph = helper.make_graph(
+ [index, seq_node, sequence_erase_node],
+ "test_sequence_erase",
+ inputs=graph_inputs,
+ outputs=[helper.make_tensor_sequence_value_info("output",
TensorProto.FLOAT, [32, 32])],
+ )
+ model = helper.make_model(graph, producer_name="test_sequence_erase")
+ check_correctness(model)
+
+
[email protected]("explicit_position", [True, False])
+def test_sequence_insert(explicit_position: bool):
+ seq_node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=4)
+ index = make_constant_node("index", TensorProto.INT64, (), [0])
+ node_input = ["sequence", "value", "index"] if explicit_position else
["sequence", "value"]
+ sequence_insert_node = helper.make_node("SequenceInsert", node_input,
["output"])
+ graph = helper.make_graph(
+ [index, seq_node, sequence_insert_node],
+ "test_sequence_insert",
+ inputs=[*graph_inputs, helper.make_tensor_value_info("value",
TensorProto.FLOAT, [32, 32])],
+ outputs=[helper.make_tensor_sequence_value_info("output",
TensorProto.FLOAT, [32, 32])],
+ )
+ model = helper.make_model(graph, producer_name="test_sequence_insert")
+ check_correctness(model)
+
+
[email protected]("new_axis", [0, 1])
+def test_concat_from_sequence(new_axis: Literal[0, 1]):
+ if new_axis == 1:
+ pytest.skip("ConcatFromSequence with new_axis=1 is not supported yet")
+ seq_node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=2)
+ concat_from_sequence_node = helper.make_node(
+ "ConcatFromSequence", ["sequence"], ["output"], axis=1
+ )
+ graph = helper.make_graph(
+ [seq_node, concat_from_sequence_node],
+ "test_concat_from_sequence",
+ inputs=graph_inputs,
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
[64, 32])],
+ )
+ model = helper.make_model(graph, producer_name="test_concat_from_sequence")
+ check_correctness(model)
+
+
[email protected]("split", [2, [16, 48]])
+def test_split_to_sequence(split):
+ split_to_sequence_node = helper.make_node(
+ "SplitToSequence",
+ ["data", "split"],
+ ["output"],
+ axis=0,
+ )
+ split_shape = [len(split)] if isinstance(split, list) else ()
+ split_node = make_constant_node(
+ "split", TensorProto.INT64, split_shape, [split] if isinstance(split,
int) else split
+ )
+ graph = helper.make_graph(
+ [split_node, split_to_sequence_node],
+ "test_split_to_sequence",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64,
32])],
+ outputs=[helper.make_tensor_sequence_value_info("output",
TensorProto.FLOAT, [32, 32])],
+ )
+ model = helper.make_model(graph, producer_name="test_split_to_sequence")
+ check_correctness(model)
+
+
+def test_sequence_at():
+ seq_node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=4)
+ index = make_constant_node("index", TensorProto.INT64, (), [1])
+ node_input = ["sequence", "index"]
+ sequence_at_node = helper.make_node("SequenceAt", node_input, ["output"])
+ graph = helper.make_graph(
+ [index, seq_node, sequence_at_node],
+ "test_sequence_at",
+ inputs=graph_inputs,
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
[32, 32])],
+ )
+ model = helper.make_model(graph, producer_name="test_sequence_at")
+ check_correctness(model)
def test_symbolic_shape_deduction():
diff --git a/tests/python/relax/test_relax_operators.py
b/tests/python/relax/test_relax_operators.py
index fcb8727d85..a80b988d06 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -60,7 +60,7 @@ def test_unique(exec_mode):
result, result_sorted = run_cpu(InputModule, "foo", data,
exec_mode=exec_mode)
expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
- expected_output = [data_numpy.flatten()[index] for index in
sorted(indices, reverse=True)]
+ expected_output = [data_numpy.flatten()[index] for index in
sorted(indices)]
np.testing.assert_array_equal(expected_output_sorted,
result_sorted.numpy())
np.testing.assert_array_equal(expected_output, result.numpy())
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index d03d48968d..12436cf802 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -204,6 +204,53 @@ def test_conv1d_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_conv1d_transpose():
+ # fmt: off
+ @I.ir_module
+ class Conv1dTranspose:
+ @R.function
+ def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((128, 16,
3), "float32")):
+ gv = R.nn.conv1d_transpose(x, w, strides=2, padding=1, dilation=1,
output_padding=1, groups=8)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128),
T.int64(28)), "float32"), w: T.Buffer((T.int64(128), T.int64(16), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56)),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ data_dilate = T.alloc_buffer((T.int64(2), T.int64(128),
T.int64(55)))
+ data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58)))
+ kernel = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3)))
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(55)):
+ with T.block("data_dilate"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ data_dilate[v_i0, v_i1, v_i2] = T.if_then_else(v_i2 %
T.int64(2) == T.int64(0), x[v_i0, v_i1, v_i2 // T.int64(2)], T.float32(0.0))
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(58)):
+ with T.block("data_pad"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ data_pad[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <=
v_i2 and v_i2 < T.int64(56), data_dilate[v_i0, v_i1, v_i2 - T.int64(1)],
T.float32(0.0))
+ for o, i, w_1 in T.grid(T.int64(16), T.int64(128), T.int64(3)):
+ with T.block("kernel"):
+ v_o, v_i, v_w = T.axis.remap("SSS", [o, i, w_1])
+ kernel[v_o, v_i, v_w] = w[v_i, v_o, T.int64(2) - v_w]
+ for b, c, w_1, dc, dw in T.grid(T.int64(2), T.int64(128),
T.int64(56), T.int64(16), T.int64(3)):
+ with T.block("compute"):
+ v_b, v_c, v_w, v_dc, v_dw = T.axis.remap("SSSRR", [b, c,
w_1, dc, dw])
+ with T.init():
+ compute[v_b, v_c, v_w] = T.float32(0.0)
+ compute[v_b, v_c, v_w] = compute[v_b, v_c, v_w] +
data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_w + v_dw] * kernel[v_c
% T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dw]
+
+ @R.function
+ def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((128,
16, 3), dtype="float32")) -> R.Tensor((2, 128, 56), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.conv1d_transpose, (x, w),
out_sinfo=R.Tensor((2, 128, 56), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Conv1dTranspose)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_conv2d():
# fmt: off
@tvm.script.ir_module