This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new dddce91b3f [NN][Op]ConvTranspose1D (#15741)
dddce91b3f is described below
commit dddce91b3f0b7b3254460a590d1ee92d6ee9ba90
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Sep 15 22:59:48 2023 -0700
[NN][Op]ConvTranspose1D (#15741)
This pr introduces a new op ConvTranspose1D to nn.Module.
---
python/tvm/relax/frontend/nn/__init__.py | 11 +++-
python/tvm/relax/frontend/nn/modules.py | 67 +++++++++++++++++++++
python/tvm/relax/frontend/nn/op.py | 80 ++++++++++++++++++++++++++
tests/python/relax/test_frontend_nn_modules.py | 26 +++++++++
4 files changed, 183 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/nn/__init__.py
b/python/tvm/relax/frontend/nn/__init__.py
index b4f3c250fb..7e086abeee 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -17,6 +17,15 @@
"""A PyTorch-like API to build IRModules."""
from . import op, spec
from .core import Effect, Module, ModuleList, Parameter, Tensor, ExternModule
-from .modules import Embedding, IOEffect, KVCache, Linear, Conv1D, LayerNorm,
RMSNorm
+from .modules import (
+ Embedding,
+ IOEffect,
+ KVCache,
+ Linear,
+ Conv1D,
+ ConvTranspose1D,
+ LayerNorm,
+ RMSNorm,
+)
from .op import *
from .subroutine import SubroutineMixin
diff --git a/python/tvm/relax/frontend/nn/modules.py
b/python/tvm/relax/frontend/nn/modules.py
index f07bcb1ce3..4e612dfbc2 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -312,6 +312,73 @@ class Conv2D(Module):
)
class ConvTranspose1D(Module):
    """
    Module for ConvTranspose1D layer.

    Applies a 1D transposed convolution over the input, mirroring the
    interface of ``torch.nn.ConvTranspose1d``.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input.
    out_channels : int
        Number of channels produced by the transposed convolution.
    kernel_size : int
        Size of the convolving kernel.
    stride : int
        Stride of the convolution. Default: 1.
    padding : int
        Padding applied to both sides of the input. Default: 0.
    output_padding : int
        Additional size added to one side of the output shape. Default: 0.
    dilation : int
        Spacing between kernel elements. Default: 1.
    groups : int
        Number of blocked connections from input channels to output
        channels. ``in_channels`` and ``out_channels`` should both be
        divisible by ``groups``. Default: 1.
    bias : bool
        If ``True``, adds a learnable bias to the output. Default: True.
    dtype : Optional[str]
        Data type of the weight (and bias) parameters.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        output_padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        dtype: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.dilation = dilation
        self.groups = groups

        # Weight follows the "IOW" kernel layout used by
        # op.conv1d_transpose: (in_channels, out_channels // groups, k).
        # NOTE: `//` on ints already yields an int, so no cast is needed.
        self.weight = Parameter(
            (
                self.in_channels,
                self.out_channels // self.groups,
                self.kernel_size,
            ),
            dtype,
        )
        if bias:
            self.bias = Parameter((self.out_channels,), dtype)
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward method for convtranspose1d layer.

        Parameters
        ----------
        x : Tensor
            The input tensor.

        Returns
        -------
        ret : Tensor
            The output tensor for the convtranspose1d layer.
        """
        return op.conv1d_transpose(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.output_padding,
            self.dilation,
            self.groups,
        )
+
+
class LayerNorm(Module):
"""
Module for Layer Normalization
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 1f00eb058a..8afeb2c118 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -416,6 +416,86 @@ def conv2d(
return _wrap_nested(conv_out, name)
def conv1d_transpose(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Optional[Union[int, Tuple[int]]] = 1,
    padding: Optional[Union[int, Tuple[int, ...]]] = 0,
    output_padding: Optional[Union[int, Tuple[int]]] = 0,
    dilation: Optional[Union[int, Tuple]] = 1,
    groups: Optional[int] = 1,
    name: str = "conv1d_transpose",
) -> Tensor:
    """1D transposed convolution operator.

    This operator can be seen as the gradient operator of conv1d.

    The output shape can be explained in the simple case when
    `data_layout == "NCW"` and `kernel_layout == "IOW"`. Suppose `x` has
    shape `(N, in_channel, in_w)` and `weight` has shape
    `(in_channel, out_channel, weight_w)`; we need to assure that
    `in_channel % groups == 0`. The shape of the output will be
    `(N, out_channel * groups, out_w)`, where

    - `out_w = ((in_w - 1) * stride[0] + weight_w - 2 * padding[0]
      + output_padding[0])`

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    weight : Tensor
        The weight tensor.

    bias : Optional[Tensor]
        Optional bias of shape `(out_channels,)`; when given it is
        reshaped to `(1, out_channels, 1)` and added to the result.

    stride : Optional[Union[int, Tuple[int]]]
        The stride of convolution. It is required to have length 1.

    padding : Optional[Union[int, Tuple[int, ...]]]
        The padding of convolution on both sides of inputs before convolution.
        It is required to have length either 1 or 2.

    output_padding : Optional[Union[int, Tuple[int]]]
        Used to disambiguate the output shape.

    dilation : Optional[Union[int, Tuple]]
        Specifies the dilation rate to be used for dilated convolution.
        It is required to have length 1.

    groups : Optional[int]
        Number of groups to split the input into for grouped convolution.
        The number of input and output channels should be divisible by the
        number of groups.

    name : str
        Name hint for the resulting relax expression.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    # Default layouts of the underlying relax op are used
    # (data "NCW", kernel "IOW").
    conv_out = _op.nn.conv1d_transpose(
        data=x._expr,
        weight=weight._expr,
        strides=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )
    if bias is not None:
        # Reshape bias to (1, C, 1) so it broadcasts over batch and width.
        conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1]))

    return _wrap_nested(conv_out, name)
+
+
def maximum(x1: Tensor, x2: Tensor, name: str = "maximum"):
"""Element-wise maximum
diff --git a/tests/python/relax/test_frontend_nn_modules.py
b/tests/python/relax/test_frontend_nn_modules.py
index 524472feb4..f3d248eab4 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -175,6 +175,32 @@ def test_conv1d():
assert_structural_equal(tvm_mod["forward"], forward, True)
def test_conv1d_transpose():
    # Expected relax IR for ConvTranspose1D(3, 32, 3, bias=True) applied to a
    # (1, 3, 30) float32 input: out_w = (30 - 1) * 1 + 3 - 0 + 0 = 32.
    # fmt: off
    @R.function
    def forward(x: R.Tensor((1, 3, 30), dtype="float32"), _io: R.Object, weight: R.Tensor((3, 32, 3), dtype="float32"), bias: R.Tensor((32,), dtype="float32")) -> R.Tuple(R.Tensor((1, 32, 32), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            # Transposed convolution with the module's default layouts.
            lv1: R.Tensor((1, 32, 32), dtype="float32") = R.nn.conv1d_transpose(x, weight, strides=[1], padding=[0, 0], output_padding=[0], dilation=[1], groups=1, data_layout="NCW", kernel_layout="IOW", out_layout="NCW", out_dtype="void")
            # Bias reshaped to (1, 32, 1) so it broadcasts over batch/width.
            lv2: R.Tensor((1, 32, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1]))
            conv1d_transpose: R.Tensor((1, 32, 32), dtype="float32") = R.add(lv1, lv2)
            gv1: R.Tuple(R.Tensor((1, 32, 32), dtype="float32"), R.Tuple(R.Object)) = conv1d_transpose, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    # Export the nn.Module to relax and compare against the expected IR.
    mod = modules.ConvTranspose1D(3, 32, 3, bias=True)
    tvm_mod, _ = mod.export_tvm(
        spec={
            "forward": {
                "x": spec.Tensor([1, 3, 30], "float32"),
            }
        },
        debug=True,  # debug=True threads the _io effect through the function
    )
    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
def test_layer_norm():
@R.function
def forward(