This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7336379dea [Unity][Frontend][NN] Better support for dynamic convolutions (#16427)
7336379dea is described below
commit 7336379deacdc7ffff073f8173361877eab1dd03
Author: Josh Fromm <[email protected]>
AuthorDate: Thu Jan 18 16:42:40 2024 -0800
[Unity][Frontend][NN] Better support for dynamic convolutions (#16427)
* Allow cutlass to skip dynamic convolutions, allow more dynamism in nn.Conv2D
* Formatting
---
python/tvm/relax/backend/contrib/cutlass.py | 5 ++++
python/tvm/relax/frontend/nn/modules.py | 8 +++++-
tests/python/relax/test_codegen_cutlass.py | 33 ++++++++++++++++++++---
tests/python/relax/test_frontend_nn_modules.py | 37 ++++++++++++++++++++++++++
4 files changed, 79 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index de7eb54b99..a611bee2bb 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -132,6 +132,11 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
if not _check_residual(conv2d_call, context):
return False
+ # Check if any dimensions are symbolic.
+ for dim in data.struct_info.shape.values:
+ if isinstance(dim, tvm.tir.Var):
+ return False
+
# pylint: disable=invalid-name
IC = data.struct_info.shape.values[3]
OC = weight.struct_info.shape.values[0]
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index f1b785b51a..29b9c7fcca 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -235,10 +235,16 @@ class Conv2D(Module):
self.dilation = dilation
self.groups = groups
+ # Allow dynamic input channels.
+ if isinstance(self.in_channels, int):
+ in_channels = int(self.in_channels / self.groups)
+ else:
+ in_channels = tir.floordiv(self.in_channels, self.groups)
+
self.weight = Parameter(
(
self.out_channels,
- int(self.in_channels / self.groups),
+ in_channels,
self.kernel_size,
self.kernel_size,
),
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 5ae462774d..11437f7d68 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -107,9 +107,7 @@ def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
return f(*inputs).numpy()
-def get_result_with_relax_cutlass_offload(
- mod, *args, assert_all_bindings_fused=True, num_final_bindings=1
-):
+def build_cutlass(mod, assert_all_bindings_fused=True, num_final_bindings=1):
mod = partition_for_cutlass(mod)
if assert_all_bindings_fused:
@@ -117,7 +115,13 @@ def get_result_with_relax_cutlass_offload(
codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
mod = codegen_pass(mod)
+ return mod
+
+def get_result_with_relax_cutlass_offload(
+ mod, *args, assert_all_bindings_fused=True, num_final_bindings=1
+):
+ mod = build_cutlass(mod, assert_all_bindings_fused, num_final_bindings)
return build_and_run(mod, args, "cuda")
@@ -269,6 +273,29 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_bloc
tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+@pytest.mark.parametrize(
+ "data_shape, weight_shape, dtype",
+ [
+ # batch dynamism
+ ((T.Var("n", "int64"), 32, 32, 16), (32, 3, 3, 16), "float16"),
+ # channel dynamism
+ ((16, 32, 32, T.Var("c", "int64")), (32, 3, 3, T.Var("c", "int64")), "float16"),
+ ],
+)
+def test_conv2d_dynamic(data_shape, weight_shape, dtype):
+ # Create dynamic conv2d module.
+ mod = get_relax_conv2d_module(
+ data_shape,
+ weight_shape,
+ dtype,
+ )
+ # Attempt to offload to cutlass, should run without an error
+ # but not offload due to incompatibility.
+ mod = build_cutlass(mod)
+ # Check that no cutlass call is introduced (until we support dynamism).
+ assert "call_dps" not in str(mod.__repr__())
+
+
def test_cutlass_partition_conv2d_residual_blocked():
@tvm.script.ir_module
class Conv2dReLU:
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 61fe95bccb..f438f38705 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -26,6 +26,7 @@ from tvm.ir import assert_structural_equal
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import core, modules, spec
from tvm.script import ir as I
+from tvm.script import tir as T
from tvm.script import relax as R
@@ -245,6 +246,42 @@ def test_conv2d():
assert_structural_equal(tvm_mod["forward"], forward, True)
+def test_conv2d_dynamic():
+ @R.function
+ def forward(
+ x: R.Tensor(("n", "c", "h", "w"), dtype="float32"),
+ _io: R.Object,
+ weight: R.Tensor((32, "in_channels", 3, 3), dtype="float32"),
+ bias: R.Tensor((32,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor(("n", 32, "h - 2", "w - 2"), dtype="float32"), R.Tuple(R.Object)):
+ n = T.int64()
+ h = T.int64()
+ w = T.int64()
+ c = T.int64()
+ in_channels = T.int64()
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv1: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = R.nn.conv2d(x, weight)
+ lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1]))
+ conv2d: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = R.add(lv1, lv2)
+ gv1: R.Tuple(
+ R.Tensor((n, 32, h - 2, w - 2), dtype="float32"), R.Tuple(R.Object)
+ ) = conv2d, (_io,)
+ R.output(gv1)
+ return gv1
+
+ mod = modules.Conv2D(tvm.tir.Var("in_channels", "int64"), 32, 3, bias=True)
+ tvm_mod, _ = mod.export_tvm(
+ spec={
+ "forward": {
+ "x": spec.Tensor(["n", "c", "h", "w"], "float32"),
+ }
+ },
+ debug=True,
+ )
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
def test_rms_norm():
@R.function
def forward(