This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ec8683fb6 [MetaSchedule][Test] Add unittests for CBR (#12252)
4ec8683fb6 is described below

commit 4ec8683fb65ce741279891b2fb5c0437680fd685
Author: Junru Shao <[email protected]>
AuthorDate: Mon Aug 1 05:36:24 2022 -0700

    [MetaSchedule][Test] Add unittests for CBR (#12252)
---
 python/tvm/meta_schedule/testing/te_workload.py    |   2 +-
 .../unittest/test_meta_schedule_space_cpu.py       | 175 +++++++++++++++++++++
 .../unittest/test_meta_schedule_space_cuda.py      |  91 +++++++++++
 .../test_tir_analysis_estimate_tir_flops.py        |   2 +-
 4 files changed, 268 insertions(+), 2 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py
index 0d1fc0a4d8..6fac1c2960 100644
--- a/python/tvm/meta_schedule/testing/te_workload.py
+++ b/python/tvm/meta_schedule/testing/te_workload.py
@@ -868,7 +868,7 @@ CONFIGS = {
             (2048, 2048),
         ],
     ),
-    "C2d-BN-RELU": (
+    "CBR": (
         conv2d_nhwc_bn_relu,
         [
             (1, 224, 224, 3, 64, 7, 2, 3),
diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 051ccfd5cf..e0d7b29c89 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -2244,6 +2244,180 @@ def test_cpu_sfm():
     )
 
 
+def test_cpu_cbr():
+    # fmt: off
+    @T.prim_func
+    def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3, 64), "float32"], bias: T.Buffer[64, "float32"], bn_offset: T.Buffer[64, "float32"], bn_scale: T.Buffer[64, "float32"], compute: T.Buffer[(1, 112, 112, 64), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
+            Conv2dOutput = T.alloc_buffer([1, 112, 112, 64], dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 2, 7, 1, 1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2):
+                with T.block("Conv2dOutput"):
+                    nn = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                    yy = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + i1_2 * 28 + i1_3)
+                    xx = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + i2_2 * 2 + i2_3)
+                    ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 2 + i3_2 * 2 + i3_3)
+                    ry = T.axis.reduce(7, i4_1 + i4_0)
+                    rx = T.axis.reduce(7, i5_0 + i5_1)
+                    rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
+                    T.reads(data[nn, yy * 2 + ry - 3, xx * 2 + rx - 3, rc], kernel[ry, rx, rc, ff])
+                    T.writes(Conv2dOutput[nn, yy, xx, ff])
+                    T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                    with T.init():
+                        Conv2dOutput[nn, yy, xx, ff] = T.float32(0)
+                    Conv2dOutput[nn, yy, xx, ff] = Conv2dOutput[nn, yy, xx, ff] + T.if_then_else(3 <= yy * 2 + ry and yy * 2 + ry < 227 and 3 <= xx * 2 + rx and xx * 2 + rx < 227, data[nn, yy * 2 + ry - 3, xx * 2 + rx - 3, rc], T.float32(0), dtype="float32") * kernel[ry, rx, rc, ff]
+            for i0, i1, i2, i3 in T.grid(1, 112, 112, 64):
+                with T.block("compute"):
+                    i0_4, i1_4, i2_4, i3_4 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(Conv2dOutput[i0_4, i1_4, i2_4, i3_4], bias[i3_4], bn_scale[i3_4], bn_offset[i3_4])
+                    T.writes(compute[i0_4, i1_4, i2_4, i3_4])
+                    compute[i0_4, i1_4, i2_4, i3_4] = T.max((Conv2dOutput[i0_4, i1_4, i2_4, i3_4] + bias[i3_4]) * bn_scale[i3_4] + bn_offset[i3_4], T.float32(0))
+    @T.prim_func
+    def cbr_1(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3, 64), "float32"], bias: T.Buffer[64, "float32"], bn_offset: T.Buffer[64, "float32"], bn_scale: T.Buffer[64, "float32"], compute: T.Buffer[(1, 112, 112, 64), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
+            PaddedInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+            Conv2dOutput = T.alloc_buffer([1, 112, 112, 64], dtype="float32")
+            for i0_0, i1_0 in T.grid(1, 2):
+                for ax0, ax1, ax2, ax3 in T.grid(1, 117, 229, 3):
+                    with T.block("PaddedInput"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(230, i1_0 * 112 + ax1)
+                        i2 = T.axis.spatial(230, ax2)
+                        i3 = T.axis.spatial(3, ax3)
+                        T.reads(data[i0, i1 - 3, i2 - 3, i3])
+                        T.writes(PaddedInput[i0, i1, i2, i3])
+                        PaddedInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, data[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
+                for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(7, 1, 1, 2, 2, 32):
+                    for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2):
+                        with T.block("Conv2dOutput"):
+                            nn = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                            yy = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + i1_2 * 28 + i1_3)
+                            xx = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + i2_2 * 2 + i2_3)
+                            ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 2 + i3_2 * 2 + i3_3)
+                            ry = T.axis.reduce(7, i4_1 + i4_0)
+                            rx = T.axis.reduce(7, i5_0 + i5_1)
+                            rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
+                            T.reads(PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc], kernel[ry, rx, rc, ff])
+                            T.writes(Conv2dOutput[nn, yy, xx, ff])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            with T.init():
+                                Conv2dOutput[nn, yy, xx, ff] = T.float32(0)
+                            Conv2dOutput[nn, yy, xx, ff] = Conv2dOutput[nn, yy, xx, ff] + PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc] * kernel[ry, rx, rc, ff]
+                    for ax0, ax1, ax2, ax3 in T.grid(1, 28, 8, 2):
+                        with T.block("compute"):
+                            i0 = T.axis.spatial(1, ax0)
+                            i1 = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + ax1)
+                            i2 = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + ax2)
+                            i3 = T.axis.spatial(64, i3_1 * 2 + ax3)
+                            T.reads(Conv2dOutput[i0, i1, i2, i3], bias[i3], bn_scale[i3], bn_offset[i3])
+                            T.writes(compute[i0, i1, i2, i3])
+                            compute[i0, i1, i2, i3] = T.max((Conv2dOutput[i0, i1, i2, i3] + bias[i3]) * bn_scale[i3] + bn_offset[i3], T.float32(0))
+    @T.prim_func
+    def cbr_2(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3, 64), "float32"], bias: T.Buffer[64, "float32"], bn_offset: T.Buffer[64, "float32"], bn_scale: T.Buffer[64, "float32"], compute: T.Buffer[(1, 112, 112, 64), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
+            PaddedInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+            Conv2dOutput = T.alloc_buffer([1, 112, 112, 64], dtype="float32")
+            for i0_0, i1_0 in T.grid(1, 2):
+                for ax0, ax1, ax2, ax3 in T.grid(1, 117, 229, 3):
+                    with T.block("PaddedInput"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(230, i1_0 * 112 + ax1)
+                        i2 = T.axis.spatial(230, ax2)
+                        i3 = T.axis.spatial(3, ax3)
+                        T.reads(data[i0, i1 - 3, i2 - 3, i3])
+                        T.writes(PaddedInput[i0, i1, i2, i3])
+                        PaddedInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, data[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
+                for i2_0, i3_0 in T.grid(7, 1):
+                    for i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2):
+                        with T.block("Conv2dOutput"):
+                            nn = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                            yy = T.axis.spatial(112, i1_0 * 56 + i1_1 * 28 + i1_2 * 28 + i1_3)
+                            xx = T.axis.spatial(112, i2_0 * 16 + i2_1 * 8 + i2_2 * 2 + i2_3)
+                            ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 2 + i3_2 * 2 + i3_3)
+                            ry = T.axis.reduce(7, i4_1 + i4_0)
+                            rx = T.axis.reduce(7, i5_0 + i5_1)
+                            rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
+                            T.reads(PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc], kernel[ry, rx, rc, ff])
+                            T.writes(Conv2dOutput[nn, yy, xx, ff])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            with T.init():
+                                Conv2dOutput[nn, yy, xx, ff] = T.float32(0)
+                            Conv2dOutput[nn, yy, xx, ff] = Conv2dOutput[nn, yy, xx, ff] + PaddedInput[nn, yy * 2 + ry, xx * 2 + rx, rc] * kernel[ry, rx, rc, ff]
+                    for ax0, ax1, ax2, ax3 in T.grid(1, 56, 16, 64):
+                        with T.block("compute"):
+                            i0 = T.axis.spatial(1, ax0)
+                            i1 = T.axis.spatial(112, i1_0 * 56 + ax1)
+                            i2 = T.axis.spatial(112, i2_0 * 16 + ax2)
+                            i3 = T.axis.spatial(64, ax3)
+                            T.reads(Conv2dOutput[i0, i1, i2, i3], bias[i3], bn_scale[i3], bn_offset[i3])
+                            T.writes(compute[i0, i1, i2, i3])
+                            compute[i0, i1, i2, i3] = T.max((Conv2dOutput[i0, i1, i2, i3] + bias[i3]) * bn_scale[i3] + bn_offset[i3], T.float32(0))
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [2, 2, 1, 28]),
+        ("SamplePerfectTile", [7, 2, 4, 2]),
+        ("SamplePerfectTile", [1, 32, 1, 2]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [1, 3]),
+        ("SampleCategorical", 2),
+        ("SampleComputeLocation", -2),
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [2, 2, 1, 28]),
+        ("SamplePerfectTile", [7, 2, 4, 2]),
+        ("SamplePerfectTile", [1, 32, 1, 2]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [1, 3]),
+        ("SampleCategorical", 3),
+        ("SampleComputeLocation", 1),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [2, 2, 1, 28]),
+        ("SamplePerfectTile", [7, 2, 4, 2]),
+        ("SamplePerfectTile", [1, 32, 1, 2]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [1, 3]),
+        ("SampleCategorical", 2),
+        ("SampleComputeLocation", 1),
+    ]
+    mod = create_te_workload("CBR", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[cbr_0, cbr_1, cbr_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_c1d()
     test_cpu_c2d()
@@ -2256,3 +2430,4 @@ if __name__ == "__main__":
     test_cpu_t2d()
     test_cpu_nrm()
     test_cpu_sfm()
+    test_cpu_cbr()
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 8ad8991919..ae4737a362 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -1127,6 +1127,96 @@ def test_cuda_sfm():
     )
 
 
+def test_cuda_cbr():
+    # fmt: off
+    @T.prim_func
+    def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3, 64), "float32"], bias: T.Buffer[64, "float32"], bn_offset: T.Buffer[64, "float32"], bn_scale: T.Buffer[64, "float32"], compute: T.Buffer[(1, 112, 112, 64), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":512})
+            Conv2dOutput_local = T.alloc_buffer([1, 112, 112, 64], dtype="float32", scope="local")
+            PaddedInput_shared = T.alloc_buffer([1, 230, 230, 3], dtype="float32", scope="shared")
+            kernel_shared = T.alloc_buffer([7, 7, 3, 64], dtype="float32", scope="shared")
+            for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(14, thread="blockIdx.x"):
+                for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(4, thread="vthread.x"):
+                    for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(128, thread="threadIdx.x"):
+                        for i4_0, i5_0, i6_0 in T.grid(7, 1, 3):
+                            for ax0_ax1_ax2_ax3_fused in T.serial(8251):
+                                with T.block("PaddedInput_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused // 37 + i4_0)
+                                    v2 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 2 * 32 + ax0_ax1_ax2_ax3_fused % 37)
+                                    v3 = T.axis.spatial(3, i6_0)
+                                    T.reads(data[v0, v1 - 3, v2 - 3, v3])
+                                    T.writes(PaddedInput_shared[v0, v1, v2, v3])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":1})
+                                    PaddedInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, data[v0, v1 - 3, v2 - 3, v3], T.float32(0), dtype="float32")
+                            for ax0_ax1_ax2_ax3_fused in T.serial(224):
+                                with T.block("kernel_shared"):
+                                    v0 = T.axis.spatial(7, i4_0)
+                                    v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 32)
+                                    v2 = T.axis.spatial(3, i6_0)
+                                    v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
+                                    T.reads(kernel[v0, v1, v2, v3])
+                                    T.writes(kernel_shared[v0, v1, v2, v3])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":1})
+                                    kernel_shared[v0, v1, v2, v3] = kernel[v0, v1, v2, v3]
+                            for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 1, 1, 1, 1, 2, 1, 7, 1, 1, 7, 1, 8):
+                                with T.block("Conv2dOutput"):
+                                    nn = T.axis.spatial(1, i0_3 + i0_4)
+                                    yy = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 56 + i0_2_i1_2_i2_2_i3_2_fused // 16 * 7 + i1_3 * 7 + i1_4)
+                                    xx = T.axis.spatial(112, i2_4 + i0_0_i1_0_i2_0_i3_0_fused // 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i2_3)
+                                    ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + i3_3 * 8 + i3_4)
+                                    ry = T.axis.reduce(7, i4_0 + i4_1 + i4_2)
+                                    rx = T.axis.reduce(7, i5_0 * 7 + i5_1 * 7 + i5_2)
+                                    rc = T.axis.reduce(3, i6_1 + i6_2 + i6_0)
+                                    T.reads(PaddedInput_shared[nn, yy * 2 + ry, xx * 2 + rx, rc], kernel_shared[ry, rx, rc, ff])
+                                    T.writes(Conv2dOutput_local[nn, yy, xx, ff])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
+                                    with T.init():
+                                        Conv2dOutput_local[nn, yy, xx, ff] = T.float32(0)
+                                    Conv2dOutput_local[nn, yy, xx, ff] = Conv2dOutput_local[nn, yy, xx, ff] + PaddedInput_shared[nn, yy * 2 + ry, xx * 2 + rx, rc] * kernel_shared[ry, rx, rc, ff]
+                        for ax0, ax1, ax2, ax3 in T.grid(1, 7, 1, 16):
+                            with T.block("Conv2dOutput_local"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 56 + i0_2_i1_2_i2_2_i3_2_fused // 16 * 7 + ax1)
+                                v2 = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_fused // 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + ax2)
+                                v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + ax3)
+                                T.reads(Conv2dOutput_local[v0, v1, v2, v3], bias[v3], bn_scale[v3], bn_offset[v3])
+                                T.writes(compute[v0, v1, v2, v3])
+                                compute[v0, v1, v2, v3] = T.max((Conv2dOutput_local[v0, v1, v2, v3] + bias[v3]) * bn_scale[v3] + bn_offset[v3], T.float32(0))
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 2, 8, 1, 7]),
+        ("SamplePerfectTile", [7, 1, 16, 1, 1]),
+        ("SamplePerfectTile", [2, 2, 1, 2, 8]),
+        ("SamplePerfectTile", [7, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 7]),
+        ("SamplePerfectTile", [3, 1, 1]),
+        ("SampleCategorical", 0),
+        ("SampleCategorical", 0),
+        ("SampleCategorical", 3),
+    ]
+    mod = create_te_workload("CBR", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[cbr_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cuda_c1d()
     test_cuda_c2d()
@@ -1139,3 +1229,4 @@ if __name__ == "__main__":
     test_cuda_t2d()
     test_cuda_nrm()
     test_cuda_sfm()
+    test_cuda_cbr()
diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
index 1cba1a739c..68279043c6 100644
--- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
+++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
@@ -38,7 +38,7 @@ from tvm.tir.analysis import estimate_tir_flops
         ("GMM", 4194304),
         ("GRP", 28901376),
         ("T2D", 268435456),
-        ("C2d-BN-RELU", 239239168),
+        ("CBR", 239239168),
         ("TBG", 25165824),
         ("NRM", 131072),
         ("SFM", 262144),

Reply via email to