This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7336379dea [Unity][Frontend][NN] Better support for dynamic convolutions (#16427)
7336379dea is described below

commit 7336379deacdc7ffff073f8173361877eab1dd03
Author: Josh Fromm <[email protected]>
AuthorDate: Thu Jan 18 16:42:40 2024 -0800

    [Unity][Frontend][NN] Better support for dynamic convolutions (#16427)
    
    * Allow cutlass to skip dynamic convolutions, allow more dynamism in nn.Conv2D
    
    * Formatting
---
 python/tvm/relax/backend/contrib/cutlass.py    |  5 ++++
 python/tvm/relax/frontend/nn/modules.py        |  8 +++++-
 tests/python/relax/test_codegen_cutlass.py     | 33 ++++++++++++++++++++---
 tests/python/relax/test_frontend_nn_modules.py | 37 ++++++++++++++++++++++++++
 4 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index de7eb54b99..a611bee2bb 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -132,6 +132,11 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
     if not _check_residual(conv2d_call, context):
         return False
 
+    # Check if any dimensions are symbolic.
+    for dim in data.struct_info.shape.values:
+        if isinstance(dim, tvm.tir.Var):
+            return False
+
     # pylint: disable=invalid-name
     IC = data.struct_info.shape.values[3]
     OC = weight.struct_info.shape.values[0]
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index f1b785b51a..29b9c7fcca 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -235,10 +235,16 @@ class Conv2D(Module):
         self.dilation = dilation
         self.groups = groups
 
+        # Allow dynamic input channels.
+        if isinstance(self.in_channels, int):
+            in_channels = int(self.in_channels / self.groups)
+        else:
+            in_channels = tir.floordiv(self.in_channels, self.groups)
+
         self.weight = Parameter(
             (
                 self.out_channels,
-                int(self.in_channels / self.groups),
+                in_channels,
                 self.kernel_size,
                 self.kernel_size,
             ),
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 5ae462774d..11437f7d68 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -107,9 +107,7 @@ def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
     return f(*inputs).numpy()
 
 
-def get_result_with_relax_cutlass_offload(
-    mod, *args, assert_all_bindings_fused=True, num_final_bindings=1
-):
+def build_cutlass(mod, assert_all_bindings_fused=True, num_final_bindings=1):
     mod = partition_for_cutlass(mod)
 
     if assert_all_bindings_fused:
@@ -117,7 +115,13 @@ def get_result_with_relax_cutlass_offload(
 
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, 
"find_first_valid": True}})
     mod = codegen_pass(mod)
+    return mod
 
+
+def get_result_with_relax_cutlass_offload(
+    mod, *args, assert_all_bindings_fused=True, num_final_bindings=1
+):
+    mod = build_cutlass(mod, assert_all_bindings_fused, num_final_bindings)
     return build_and_run(mod, args, "cuda")
 
 
@@ -269,6 +273,29 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_bloc
     tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
+@pytest.mark.parametrize(
+    "data_shape, weight_shape, dtype",
+    [
+        # batch dynamism
+        ((T.Var("n", "int64"), 32, 32, 16), (32, 3, 3, 16), "float16"),
+        # channel dynamism
+        ((16, 32, 32, T.Var("c", "int64")), (32, 3, 3, T.Var("c", "int64")), 
"float16"),
+    ],
+)
+def test_conv2d_dynamic(data_shape, weight_shape, dtype):
+    # Create dynamic conv2d module.
+    mod = get_relax_conv2d_module(
+        data_shape,
+        weight_shape,
+        dtype,
+    )
+    # Attempt to offload to cutlass, should run without an error
+    # but not offload due to incompatibility.
+    mod = build_cutlass(mod)
+    # Check that no cutlass call is introduced (until we support dynamism).
+    assert "call_dps" not in str(mod.__repr__())
+
+
 def test_cutlass_partition_conv2d_residual_blocked():
     @tvm.script.ir_module
     class Conv2dReLU:
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 61fe95bccb..f438f38705 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -26,6 +26,7 @@ from tvm.ir import assert_structural_equal
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import core, modules, spec
 from tvm.script import ir as I
+from tvm.script import tir as T
 from tvm.script import relax as R
 
 
@@ -245,6 +246,42 @@ def test_conv2d():
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_conv2d_dynamic():
+    @R.function
+    def forward(
+        x: R.Tensor(("n", "c", "h", "w"), dtype="float32"),
+        _io: R.Object,
+        weight: R.Tensor((32, "in_channels", 3, 3), dtype="float32"),
+        bias: R.Tensor((32,), dtype="float32"),
+    ) -> R.Tuple(R.Tensor(("n", 32, "h - 2", "w - 2"), dtype="float32"), 
R.Tuple(R.Object)):
+        n = T.int64()
+        h = T.int64()
+        w = T.int64()
+        c = T.int64()
+        in_channels = T.int64()
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            lv1: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = 
R.nn.conv2d(x, weight)
+            lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, 
R.shape([1, 32, 1, 1]))
+            conv2d: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = 
R.add(lv1, lv2)
+            gv1: R.Tuple(
+                R.Tensor((n, 32, h - 2, w - 2), dtype="float32"), 
R.Tuple(R.Object)
+            ) = conv2d, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.Conv2D(tvm.tir.Var("in_channels", "int64"), 32, 3, bias=True)
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "x": spec.Tensor(["n", "c", "h", "w"], "float32"),
+            }
+        },
+        debug=True,
+    )
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_rms_norm():
     @R.function
     def forward(

Reply via email to