This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f18e2976f9 [Unity][Frontend][NN] Conv2D NN module (#15586)
f18e2976f9 is described below
commit f18e2976f9288f963a5eb561808d18984f7bdb47
Author: Josh Fromm <[email protected]>
AuthorDate: Mon Aug 21 13:57:05 2023 -0700
[Unity][Frontend][NN] Conv2D NN module (#15586)
* Add nn conv2d module and op
* Fix lint
* Remove useless mod spec
* Fix typo
* Use more standard input name
* retrigger ci
---
python/tvm/relax/frontend/nn/modules.py | 59 ++++++++++++++++++++++++++
python/tvm/relax/frontend/nn/op.py | 58 +++++++++++++++++++++++++
tests/python/relax/test_frontend_nn_modules.py | 29 +++++++++++++
tests/python/relax/test_frontend_nn_op.py | 36 ++++++++++++++++
4 files changed, 182 insertions(+)
diff --git a/python/tvm/relax/frontend/nn/modules.py
b/python/tvm/relax/frontend/nn/modules.py
index 03b47619f0..35cad1285c 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -108,6 +108,65 @@ class Linear(Module):
return x
+class Conv2D(Module):
+ """
+ Module for conv2d layer.
+ """
+
+ def __init__( # pylint: disable=too-many-arguments
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ dtype: Optional[str] = None,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+
+ self.weight = Parameter(
+ (
+ self.out_channels,
+ int(self.in_channels / self.groups),
+ self.kernel_size,
+ self.kernel_size,
+ ),
+ dtype,
+ )
+ if bias:
+ self.bias = Parameter((self.out_channels,), dtype)
+ else:
+ self.bias = None
+
+ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name
+ """
+ Forward method for conv2d layer.
+
+ Parameters
+ ----------
+ x : Tensor
+ The input tensor.
+
+ Returns
+ -------
+ ret : Tensor
+ The output tensor for the conv2d layer.
+ """
+ return op.conv2d(
+ x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups
+ )
+
+
class RMSNorm(Module):
"""
Module for rms norm layer.
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 06d5e883e5..8dea54c72f 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -179,6 +179,64 @@ def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str]
= None, name: str = "m
return _wrap_nested(_op.matmul(a._expr, b._expr, out_dtype=out_dtype),
name)
+def conv2d(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ stride: Optional[Union[int, Tuple]] = 1,
+ padding: Optional[Union[int, Tuple, str]] = 0,
+ dilation: Optional[Union[int, Tuple]] = 1,
+ groups: Optional[int] = 1,
+ name: str = "conv2d",
+) -> Tensor:
+    """Applies a 2D convolution over an input image composed of several input
planes
+
+ Parameters
+ ----------
+ x : Tensor
+ Input tensor of shape [B, N, H, W]
+
+ weight : Tensor
+ Filters of shape [O, N/groups, kH, kW]
+
+ bias : Optional[Tensor]
+ Optional bias tensor of shape [O].
+
+ stride : Optional[Union[int, Tuple]]
+ The stride of the convolving kernel. Can be a single number
+ or tuple of (sH, sW).
+
+    padding : Optional[Union[int, Tuple, str]]
+ Implicit paddings on both sides of the input.
+
+ dilation : Optional[Union[int, Tuple]]
+        The spacing between kernel elements. Can be a single number or tuple of
(dH, dW).
+
+ groups : Optional[int]
+ Split input into a number of groups.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result with shape [B, O, oH, oW].
+ """
+ conv_out = _op.nn.conv2d(
+ data=x._expr,
+ weight=weight._expr,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+ if bias is not None:
+ conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1]))
+
+ return _wrap_nested(conv_out, name)
+
+
def maximum(x1: Tensor, x2: Tensor, name: str = "maximum"):
"""Element-wise maximum
diff --git a/tests/python/relax/test_frontend_nn_modules.py
b/tests/python/relax/test_frontend_nn_modules.py
index 8302257648..f0a3cbb3f2 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -47,6 +47,35 @@ def test_linear():
assert_structural_equal(tvm_mod["forward"], forward, True)
+def test_conv2d():
+ @R.function
+ def forward(
+ x: R.Tensor((1, 3, 32, 32), dtype="float32"),
+ weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
+ bias: R.Tensor((32,), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"),
R.Tuple(R.Object)):
+ with R.dataflow():
+ lv1: R.Tensor((1, 32, 30, 30), dtype="float32") = R.nn.conv2d(x,
weight)
+ lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias,
R.shape([1, 32, 1, 1]))
+ conv2d: R.Tensor((1, 32, 30, 30), dtype="float32") = R.add(lv1,
lv2)
+ gv1: R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"),
R.Tuple(R.Object)) = conv2d, (
+ _io,
+ )
+ R.output(gv1)
+ return gv1
+
+ mod = modules.Conv2D(3, 32, 3, bias=True)
+ tvm_mod, _ = mod.export_tvm(
+ spec={
+ "forward": {
+ "x": spec.Tensor([1, 3, 32, 32], "float32"),
+ }
+ }
+ )
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
def test_rms_norm():
@R.function
def forward(
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index c23ea8a437..27d7e6d2ff 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -135,6 +135,42 @@ def test_datatype():
tvm.ir.assert_structural_equal(irmodule["test"], test)
+def test_image():
+ class Model(Module):
+ def test(self, x: Tensor, weight: Tensor, bias: Tensor):
+ conv2d_out = op.conv2d(x, weight, bias)
+ return conv2d_out
+
+ @R.function
+ def test(
+ x: R.Tensor((1, 3, 32, 32), dtype="float32"),
+ weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
+ bias: R.Tensor((32,), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"),
R.Tuple(R.Object)):
+ with R.dataflow():
+ lv1: R.Tensor((1, 32, 30, 30), dtype="float32") = R.nn.conv2d(x,
weight)
+ lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias,
R.shape([1, 32, 1, 1]))
+ conv2d: R.Tensor((1, 32, 30, 30), dtype="float32") = R.add(lv1,
lv2)
+ gv1: R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"),
R.Tuple(R.Object)) = conv2d, (
+ _io,
+ )
+ R.output(gv1)
+ return gv1
+
+ m = Model()
+ irmodule, params = m.export_tvm(
+ spec={
+ "test": {
+ "x": spec.Tensor([1, 3, 32, 32], "float32"),
+ "weight": spec.Tensor([32, 3, 3, 3], "float32"),
+ "bias": spec.Tensor([32], "float32"),
+ }
+ }
+ )
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
def test_nn():
class Model(Module):
def test(self, x: Tensor, weight: Tensor, bias: Tensor):