This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0f05116453 [Unity][MetaSchedule] Add the module_equality param for 
tune_relax flow (#14537)
0f05116453 is described below

commit 0f051164532afb053ab7c25e891523d3e96359f1
Author: Icemist <[email protected]>
AuthorDate: Mon Apr 17 13:07:22 2023 +0300

    [Unity][MetaSchedule] Add the module_equality param for tune_relax flow 
(#14537)
---
 python/tvm/meta_schedule/relax_integration.py      |  41 +++-
 src/relax/backend/task_extraction.cc               |  78 ++++++--
 src/relax/transform/meta_schedule.cc               |  21 ++-
 .../relax/test_meta_schedule_relax_integration.py  | 210 +++++++++++++++++++++
 tests/python/relax/test_relay_translator.py        |  16 +-
 5 files changed, 337 insertions(+), 29 deletions(-)

diff --git a/python/tvm/meta_schedule/relax_integration.py 
b/python/tvm/meta_schedule/relax_integration.py
index 62f5865242..c776d64763 100644
--- a/python/tvm/meta_schedule/relax_integration.py
+++ b/python/tvm/meta_schedule/relax_integration.py
@@ -57,6 +57,7 @@ def extract_tasks(
     mod: Union[IRModule, "relax.Function"],
     target: Target,
     params: Optional[Dict[str, NDArray]] = None,
+    module_equality: str = "structural",
 ) -> List[ExtractedTask]:
     """Extract tuning tasks from a relax program.
 
@@ -66,6 +67,18 @@ def extract_tasks(
         The module or function to tune
     target : tvm.target.Target
         The compilation target
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        The associated parameters of the program
+    module_equality : Optional[str]
+        A string to specify the module equality testing and hashing method.
+        It must be one of the following:
+          - "structural": Use StructuralEqual/Hash
+          - "ignore-ndarray": Same as "structural", but ignore ndarray raw 
data during
+                              equality testing and hashing.
+          - "anchor-block": Apply equality testing and hashing on the anchor 
block extracted from a
+                            given module. The "ignore-ndarray" variant is used 
for the extracted
+                            blocks or in case no anchor block is found.
+                            For the definition of the anchor block, see 
tir/analysis/analysis.py.
 
     Returns
     -------
@@ -83,7 +96,7 @@ def extract_tasks(
         target = Target(target)
     if params:
         mod = BindParams("main", params)(mod)
-    return list(_extract_task_func(mod, target))
+    return list(_extract_task_func(mod, target, module_equality))
 
 
 def extracted_tasks_to_tune_contexts(
@@ -162,6 +175,7 @@ def tune_relax(
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     seed: Optional[int] = None,
+    module_equality: str = "structural",
 ) -> Database:
     """Tune a Relax program.
 
@@ -199,6 +213,16 @@ def tune_relax(
         The search strategy to use
     seed : Optional[int]
         The random seed
+    module_equality : Optional[str]
+        A string to specify the module equality testing and hashing method.
+        It must be one of the following:
+          - "structural": Use StructuralEqual/Hash
+          - "ignore-ndarray": Same as "structural", but ignore ndarray raw 
data during
+                              equality testing and hashing.
+          - "anchor-block": Apply equality testing and hashing on the anchor 
block extracted from a
+                            given module. The "ignore-ndarray" variant is used 
for the extracted
+                            blocks or in case no anchor block is found.
+                            For the definition of the anchor block, see 
tir/analysis/analysis.py.
 
     Returns
     -------
@@ -206,7 +230,7 @@ def tune_relax(
         The database that contains the tuning records
     """
     tasks, task_weights = extracted_tasks_to_tune_contexts(
-        extracted_tasks=extract_tasks(mod, target, params),
+        extracted_tasks=extract_tasks(mod, target, params, 
module_equality=module_equality),
         work_dir=work_dir,
         space=space,
         strategy=strategy,
@@ -225,6 +249,7 @@ def tune_relax(
         cost_model=cost_model,
         measure_callbacks=measure_callbacks,
         task_scheduler=task_scheduler,
+        module_equality=module_equality,
     )
 
 
@@ -247,6 +272,7 @@ def _tune_relax(
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     seed: Optional[int] = None,
+    module_equality: str = "structural",
 ) -> Database:
     """Interface with tuning api to tune a Relax program.
 
@@ -284,6 +310,16 @@ def _tune_relax(
         The search strategy to use
     seed : Optional[int]
         The random seed
+    module_equality : Optional[str]
+        A string to specify the module equality testing and hashing method.
+        It must be one of the following:
+          - "structural": Use StructuralEqual/Hash
+          - "ignore-ndarray": Same as "structural", but ignore ndarray raw 
data during
+                              equality testing and hashing.
+          - "anchor-block": Apply equality testing and hashing on the anchor 
block extracted from a
+                            given module. The "ignore-ndarray" variant is used 
for the extracted
+                            blocks or in case no anchor block is found.
+                            For the definition of the anchor block, see 
tir/analysis/analysis.py.
 
     Returns
     -------
@@ -310,6 +346,7 @@ def _tune_relax(
         space=space,
         strategy=strategy,
         seed=seed,
+        module_equality=module_equality,
     )
     # Return original IRModule
     # This pass only makes optimization decision
diff --git a/src/relax/backend/task_extraction.cc 
b/src/relax/backend/task_extraction.cc
index 5bd764c68e..ebaee74f47 100644
--- a/src/relax/backend/task_extraction.cc
+++ b/src/relax/backend/task_extraction.cc
@@ -22,12 +22,18 @@
 #include <tvm/relax/expr_functor.h>
 #include <tvm/target/target.h>
 #include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/module_equality.h"
 
 namespace tvm {
 namespace relax {
 namespace backend {
 
-using tvm::meta_schedule::ExtractedTask;
+using meta_schedule::ExtractedTask;
+using meta_schedule::ModuleEqual;
+using meta_schedule::ModuleEquality;
+using meta_schedule::ModuleHash;
 
 /*!
  * \brief Extract the Meta-Schedule tuning task from a given IRModule.
@@ -42,22 +48,45 @@ using tvm::meta_schedule::ExtractedTask;
  *   Then we will have a ExtractedTask for all three functions, whose weight
  *   is 5 + 3 + 2 = 10.
  */
+class BlockCounter : public tir::StmtVisitor {
+ public:
+  static size_t GetBlockCount(const tir::PrimFunc& func) {
+    BlockCounter counter;
+    counter(func->body);
+    return counter.count;
+  }
+
+ private:
+  void VisitStmt_(const tir::BlockNode* op) final {
+    ++count;
+    StmtVisitor::VisitStmt_(op);
+  }
+  size_t count{0};
+};
+
 class TaskExtractor : public ExprVisitor {
  public:
-  static Array<ExtractedTask> ExtractTask(IRModule mod, Target target) {
-    TaskExtractor extractor(mod, target);
+  static Array<ExtractedTask> ExtractTask(IRModule mod, Target target, String 
mod_eq_name) {
+    TaskExtractor extractor(mod, target, mod_eq_name);
     // We go through each Relax function in the module.
     for (const auto& kv : mod->functions) {
       if (const auto* func = kv.second.as<FunctionNode>()) {
         extractor(GetRef<Function>(func));
       }
     }
-    return std::move(extractor.tasks_);
+    Array<ExtractedTask> tasks;
+    for (const auto& it : extractor.func2task_) {
+      tasks.push_back(it.second);
+    }
+    return tasks;
   }
 
  private:
-  explicit TaskExtractor(IRModule mod, Target target)
-      : mod_(std::move(mod)), target_(std::move(target)) {
+  explicit TaskExtractor(IRModule mod, Target target, String mod_eq_name)
+      : mod_(std::move(mod)),
+        target_(std::move(target)),
+        mod_eq_(ModuleEquality::Create(mod_eq_name)),
+        func2task_(/*bucket_count*/ 0, ModuleHash(*mod_eq_), 
ModuleEqual(*mod_eq_)) {
     normalize_mod_func_ = 
runtime::Registry::Get("tvm.meta_schedule.normalize_mod");
     ICHECK(normalize_mod_func_) << "Normalization function is not found.";
   }
@@ -75,33 +104,44 @@ class TaskExtractor : public ExprVisitor {
 
     const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
     const tir::PrimFunc& func = 
Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
-
-    auto it = func2task_.find(func);
+    IRModule mod = (*normalize_mod_func_)(func);
+    size_t weight = 1;
+    auto it = func2task_.find(mod);
     if (it != func2task_.end()) {
       it->second->weight += 1;
-      return;
+      const tir::PrimFunc& alt_func = 
Downcast<tir::PrimFunc>(it->first->Lookup("main"));
+      // When anchor-block based equality is used, tuning tasks 
"nn_conv2d_add_nn_relu" and
+      // "nn_conv2d_add_add_nn_relu", for example, can be identified as equal. 
Thus, one of them
+      // will be selected to tune by the code below.
+      //
+      // To make sure that we tune "nn_conv2d_add_nn_relu" and not 
"nn_conv2d_add_add_nn_relu", we
+      // count each PrimFunc's number of blocks and leave only the function with 
the smallest number of
+      // blocks. This way, "nn_conv2d_add_nn_relu" will have a smaller number 
of blocks than
+      // "nn_conv2d_add_add_nn_relu" and will be selected to tune.
+      if (BlockCounter::GetBlockCount(func) < 
BlockCounter::GetBlockCount(alt_func)) {
+        weight += it->second->weight;
+        func2task_.erase(it->first);
+      }
     }
 
-    IRModule tir_mod = (*normalize_mod_func_)(func);
     ExtractedTask task(/*task_name=*/global_var->name_hint,  //
-                       /*mod=*/tir_mod,                      //
+                       /*mod=*/mod,                          //
                        /*target=*/target_,                   //
-                       /*dispatched=*/{tir_mod},             //
-                       /*weight=*/1);
-    tasks_.push_back(task);
-    func2task_.emplace(func, task);
+                       /*dispatched=*/{mod},                 //
+                       /*weight=*/weight);
+    func2task_.emplace(mod, task);
   }
 
   IRModule mod_;
   Target target_;
-  Array<ExtractedTask> tasks_;
-  std::unordered_map<tir::PrimFunc, ExtractedTask, StructuralHash, 
StructuralEqual> func2task_;
+  std::unique_ptr<ModuleEquality> mod_eq_;
+  std::unordered_map<IRModule, ExtractedTask, ModuleHash, ModuleEqual> 
func2task_;
   const runtime::PackedFunc* normalize_mod_func_;
 };
 
 TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask")
-    .set_body_typed([](IRModule mod, Target target) {
-      return TaskExtractor::ExtractTask(std::move(mod), std::move(target));
+    .set_body_typed([](IRModule mod, Target target, String mod_eq_name) {
+      return TaskExtractor::ExtractTask(std::move(mod), std::move(target), 
std::move(mod_eq_name));
     });
 
 }  // namespace backend
diff --git a/src/relax/transform/meta_schedule.cc 
b/src/relax/transform/meta_schedule.cc
index e205e984df..c33a90ccb1 100644
--- a/src/relax/transform/meta_schedule.cc
+++ b/src/relax/transform/meta_schedule.cc
@@ -26,6 +26,9 @@
 #include <tvm/relax/tuning_api.h>
 #include <tvm/tir/transform.h>
 
+#include "../src/meta_schedule/module_equality.h"
+#include "../src/meta_schedule/trace_apply.h"
+
 namespace tvm {
 namespace relax {
 namespace transform {
@@ -105,6 +108,7 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
     }
 
     Map<GlobalVar, BaseFunc> result;
+    auto mod_eq_structural = 
meta_schedule::ModuleEquality::Create("ignore-ndarray");
     for (const auto& iter : mod->functions) {
       GlobalVar gv = iter.first;
       BaseFunc base_func = iter.second;
@@ -112,8 +116,21 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
         tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
 
         IRModule tir_mod = (*normalize_mod_func_)(prim_func);
-        if (Optional<tir::Schedule> sch = database->QuerySchedule(tir_mod, 
target, gv->name_hint)) {
-          IRModule new_mod = sch.value()->mod();
+        if (Optional<meta_schedule::TuningRecord> opt_record =
+                database->QueryTuningRecord(tir_mod, target, gv->name_hint)) {
+          meta_schedule::TuningRecord record = opt_record.value();
+          tir::Schedule sch =
+              tir::Schedule::Traced(tir_mod, /*seed=*/-1, /*debug_mask=*/0,
+                                    
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
+          if (!mod_eq_structural->Equal(tir_mod, record->workload->mod)) {
+            // When the database lookup succeeds while structural equality 
check fails,
+            // it implies that the anchor block based equality has been used 
during tuning.
+            // The trace in the record cannot directly be applied to this 
query module.
+            meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace, 
target);
+          } else {
+            record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
+          }
+          IRModule new_mod = sch->mod();
           ICHECK_EQ(new_mod->functions.size(), 1);
           BaseFunc new_base_func = (*new_mod->functions.begin()).second;
           ICHECK(new_base_func->IsInstance<tir::PrimFuncNode>());
diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py 
b/tests/python/relax/test_meta_schedule_relax_integration.py
new file mode 100644
index 0000000000..a66a29405b
--- /dev/null
+++ b/tests/python/relax/test_meta_schedule_relax_integration.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Integration test for MetaSchedule"""
+
+import numpy as np
+import pytest
+import tempfile
+import tvm
+import tvm.testing
+from tvm import IRModule
+from tvm import meta_schedule as ms
+from tvm import relax, tir
+from tvm.ir import transform
+
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.script import relax as R
+
+# fmt: off
[email protected]_module
+class Module0:
+    @R.function
+    def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8, 
8, 4), dtype="int32"):
+        cls = Module0
+        with R.dataflow():
+            c = 
R.const([[[[-171701247],[-1719837685],[1801664104],[-634316588]],[[920159370],[-132073802],[2142531563],[1465185701]],[[-1505608067],[1737948828],[1581089391],[-1986167320]]],[[[-1449581822],[35714587],[496324563],[-1430879015]],[[-1615680873],[1198514997],[1494683955],[1567376558]],[[1319924884],[-380548171],[296785437],[-1546305981]]],[[[-398644701],[-2004794585],[-1850413687],[2072643657]],[[847950121],[-544212073],[-199532669],[-343273682]],[[953721562],[-1930209358],
 [...]
+            lv: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(data, c, 
strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4, 
data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
+            b = R.const([[[[1, 1, 1, 1]]]], "int32")
+            lv1: R.Tensor((1, 8, 8, 4), dtype="int32") = R.add(lv, b)
+            c1 = 
R.const([[[[2042349344],[-2076067063],[1528163722],[-1156452837]],[[-2097172051],[1137787079],[-601389657],[1907495997]],[[987801941],[1073738593],[-1410339796],[-689755358]]],[[[90351522],[-44886952],[-1914103775],[-691553659]],[[-1288505112],[-1376578817],[-2067933148],[-1413101824]],[[1261422027],[-156976862],[-1185734459],[1608778622]]],[[[-664209483],[1907479806],[1838595152],[464942526]],[[877953160],[415131837],[-2010736511],[1218242769]],[[-1440127632],[112931],[
 [...]
+            lv2: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(lv1, c1, 
strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4, 
data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
+            c2 = 
R.const([[[[687940110],[-910571705],[-901609800],[-500525928]],[[506872399],[1070176297],[-305936110],[1625439784]],[[-1565626954],[-1705688881],[-866370805],[-1750740826]]],[[[300497007],[-626864803],[390295545],[222549121]],[[319224543],[-2003064970],[657992492],[2014175448]],[[653278589],[-768810984],[-294555581],[-1197167662]]],[[[1703154671],[-1540759805],[-568817430],[-1729755444]],[[-275458074],[2078945571],[1683298006],[-1029327874]],[[1315093181],[159010501],[87
 [...]
+            lv3: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(lv2, c2, 
strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4, 
data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
+            gv: R.Tensor((1, 8, 8, 4), dtype="int32") = lv3
+            R.output(gv)
+        return gv
+
+# fmt: on
+
+# fmt: off
[email protected]_module
+class Module:
+    @T.prim_func
+    def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), 
T.int64(4)), "int32"), DepthwiseConv2d: T.Buffer((T.int64(1), T.int64(8), 
T.int64(8), T.int64(4)), "int32")):
+        T.func_attr({"op_pattern": 4, "tir.noalias": True})
+        # with T.block("root"):
+        PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), 
T.int64(4)), "int32")
+        fused_constant = T.allocate_const([-171701247, -1719837685, 
1801664104, -634316588, 920159370, -132073802, 2142531563, 1465185701, 
-1505608067, 1737948828, 1581089391, -1986167320, -1449581822, 35714587, 
496324563, -1430879015, -1615680873, 1198514997, 1494683955, 1567376558, 
1319924884, -380548171, 296785437, -1546305981, -398644701, -2004794585, 
-1850413687, 2072643657, 847950121, -544212073, -199532669, -343273682, 
953721562, -1930209358, 1573600108, -577689853], "int32", [3,  [...]
+        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), 
T.int64(4)):
+            with T.block("PaddedInput"):
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - 
T.int64(1), v_i3])
+                T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3])
+                PaddedInput[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 
and v_i2 < T.int64(9), rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - 
T.int64(1), v_i3], 0)
+        for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), 
T.int64(4), T.int64(3), T.int64(3)):
+            with T.block("DepthwiseConv2d"):
+                v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, 
j, c, di, dj])
+                fused_constant_1 = T.Buffer((3, 3, 4, 1), "int32", 
data=fused_constant)
+                T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], 
fused_constant_1[v_di, v_dj, v_c, T.int64(0)])
+                T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c])
+                with T.init():
+                    DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
+                DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, 
v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * 
fused_constant_1[v_di, v_dj, v_c, T.int64(0)]
+
+    @T.prim_func
+    def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), 
T.int64(4)), "int32"), DepthwiseConv2d0: T.Buffer((T.int64(1), T.int64(8), 
T.int64(8), T.int64(4)), "int32")):
+        T.func_attr({"op_pattern": 4, "tir.noalias": True})
+        # with T.block("root"):
+        PaddedInput0 = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), 
T.int64(4)), "int32")
+        fused_constant0 = T.allocate_const([2042349344, -2076067063, 
1528163722, -1156452837, -2097172051, 1137787079, -601389657, 1907495997, 
987801941, 1073738593, -1410339796, -689755358, 90351522, -44886952, 
-1914103775, -691553659, -1288505112, -1376578817, -2067933148, -1413101824, 
1261422027, -156976862, -1185734459, 1608778622, -664209483, 1907479806, 
1838595152, 464942526, 877953160, 415131837, -2010736511, 1218242769, 
-1440127632, 112931, 521745784, -1931145893], "int32", [3, 3 [...]
+        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), 
T.int64(4)):
+            with T.block("PaddedInput"):
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - 
T.int64(1), v_i3])
+                T.writes(PaddedInput0[v_i0, v_i1, v_i2, v_i3])
+                PaddedInput0[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 
and v_i2 < T.int64(9), rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - 
T.int64(1), v_i3], 0)
+        for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), 
T.int64(4), T.int64(3), T.int64(3)):
+            with T.block("DepthwiseConv2d"):
+                v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, 
j, c, di, dj])
+                fused_constant0_1 = T.Buffer((3, 3, 4, 1), "int32", 
data=fused_constant0)
+                T.reads(PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c], 
fused_constant0_1[v_di, v_dj, v_c, T.int64(0)])
+                T.writes(DepthwiseConv2d0[v_b, v_i, v_j, v_c])
+                with T.init():
+                    DepthwiseConv2d0[v_b, v_i, v_j, v_c] = 0
+                DepthwiseConv2d0[v_b, v_i, v_j, v_c] = DepthwiseConv2d0[v_b, 
v_i, v_j, v_c] + PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c] * 
fused_constant0_1[v_di, v_dj, v_c, T.int64(0)]
+
+    @T.prim_func
+    def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8), 
T.int64(4)), "int32"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(8), 
T.int64(4)), "int32")):
+        T.func_attr({"tir.noalias": True})
+        # with T.block("root"):
+        PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), 
T.int64(4)), "int32")
+        DepthwiseConv2d = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(8), 
T.int64(4)), "int32")
+        fused_nn_conv2d_constant = T.allocate_const([1, 1, 1, 1], "int32", [1, 
1, 1, 4])
+        fused_constant_2 = T.allocate_const([687940110, -910571705, 
-901609800, -500525928, 506872399, 1070176297, -305936110, 1625439784, 
-1565626954, -1705688881, -866370805, -1750740826, 300497007, -626864803, 
390295545, 222549121, 319224543, -2003064970, 657992492, 2014175448, 653278589, 
-768810984, -294555581, -1197167662, 1703154671, -1540759805, -568817430, 
-1729755444, -275458074, 2078945571, 1683298006, -1029327874, 1315093181, 
159010501, 875694807, -223655381], "int32", [3, 3, 4, 1])
+        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), 
T.int64(4)):
+            with T.block("PaddedInput"):
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
+                T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3])
+                PaddedInput[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 
and v_i2 < T.int64(9), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 
0)
+        for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), 
T.int64(4), T.int64(3), T.int64(3)):
+            with T.block("DepthwiseConv2d"):
+                v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, 
j, c, di, dj])
+                fused_constant_2_1 = T.Buffer((3, 3, 4, 1), "int32", 
data=fused_constant_2)
+                T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], 
fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)])
+                T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c])
+                with T.init():
+                    DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
+                DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, 
v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * 
fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)]
+        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(8), T.int64(8), 
T.int64(4)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
+                fused_nn_conv2d_constant_1 = T.Buffer((1, 1, 1, 4), "int32", 
data=fused_nn_conv2d_constant)
+                T.reads(DepthwiseConv2d[v_ax0, v_ax1, v_ax2, v_ax3], 
fused_nn_conv2d_constant_1[v_ax0, T.int64(0), T.int64(0), v_ax3])
+                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+                T_add[v_ax0, v_ax1, v_ax2, v_ax3] = DepthwiseConv2d[v_ax0, 
v_ax1, v_ax2, v_ax3] + fused_nn_conv2d_constant_1[v_ax0, T.int64(0), 
T.int64(0), v_ax3]
+
+    @R.function
+    def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8, 
8, 4), dtype="int32"):
+        cls = Module
+        with R.dataflow():
+            lv = R.call_tir(cls.fused_conv2d_add, data, out_sinfo=R.Tensor((1, 
8, 8, 4), dtype="int32"))
+            lv2 = R.call_tir(cls.conv2d, lv, out_sinfo=R.Tensor((1, 8, 8, 4), 
dtype="int32"))
+            lv3 = R.call_tir(cls.conv2d0, lv2, out_sinfo=R.Tensor((1, 8, 8, 
4), dtype="int32"))
+            gv: R.Tensor((1, 8, 8, 4), dtype="int32") = lv3
+            R.output(gv)
+        return gv
+# fmt: on
+
+
+def test_extracting_tasks():
+    target = "llvm -mcpu=core-avx2 -num-cores=1"
+
+    relax_mod = Module0
+    relax_mod = relax.transform.LegalizeOps()(relax_mod)
+    relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod)
+    relax_mod = relax.transform.FuseOps()(relax_mod)
+    relax_mod = relax.transform.FoldConstant()(relax_mod)
+    relax_mod = relax.transform.FuseTIR()(relax_mod)
+
+    relax_expectation = {
+        "structural": 2,  # The relax constants do not reach the tir at the 
lowering.
+        "ignore-ndarray": 2,
+        "anchor-block": 1,
+    }
+    for module_equality, count in relax_expectation.items():
+        extracted_tasks = ms.relax_integration.extract_tasks(
+            relax_mod,
+            target,
+            {},
+            module_equality=module_equality,
+        )
+        assert len(extracted_tasks) == count
+
+    tir_relax_mod = Module
+    tir_relax_expectation = {"structural": 3, "ignore-ndarray": 2, 
"anchor-block": 1}
+    for module_equality, count in tir_relax_expectation.items():
+        extracted_tasks = ms.relax_integration.extract_tasks(
+            tir_relax_mod,
+            target,
+            {},
+            module_equality=module_equality,
+        )
+        assert len(extracted_tasks) == count
+
+
[email protected]("module_equality", ["structural", "ignore-ndarray", 
"anchor-block"])
+def test_using_anchor_trace(module_equality):
+    relax_mod = Module
+    target = "llvm -mcpu=core-avx2 -num-cores=1"
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = ms.relax_integration.tune_relax(
+            mod=relax_mod,
+            params={},
+            target=target,
+            work_dir=work_dir,
+            # for faster tuning
+            max_trials_global=100,
+            max_trials_per_task=4,
+            num_trials_per_iter=4,
+            strategy="replay-trace",
+            module_equality=module_equality,
+            seed=0,
+        )
+
+    ms.relax_integration.compile_relax(
+        database,
+        mod=relax_mod,
+        target=target,
+        params={},
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_relay_translator.py 
b/tests/python/relax/test_relay_translator.py
index d3cd47b9e6..54cd1b243d 100644
--- a/tests/python/relax/test_relay_translator.py
+++ b/tests/python/relax/test_relay_translator.py
@@ -74,6 +74,7 @@ def relax_build_and_run(mod, target, dev, params, data):
         db = ms.relax_integration.tune_relax(
             mod=mod,
             target=target,
+            params=params,
             task_scheduler="round-robin",
             num_trials_per_iter=32,
             max_trials_per_task=32,
@@ -99,7 +100,6 @@ def verify_e2e_translation(target_str, layout, batch_size, 
image_shape):
     input_shape = (1, *image_shape)
     data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev)
     relax_mod = relay_translator.from_relay(relay_mod["main"], target, params)
-    assert relax_mod["main"].attrs["global_symbol"] == "main"
 
     _, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data)
     _, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data)
@@ -123,7 +123,7 @@ def test_verify_e2e_translation_gpu(layout, batch_size, 
image_shape):
     verify_e2e_translation("cuda", layout, batch_size, image_shape)
 
 
-def verify_extracted_tasks(target_str, layout, batch_size, image_shape):
+def verify_extracted_tasks(target_str, layout, batch_size, image_shape, 
module_equality):
     target = Target(target_str)
     relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape)
     relax_mod = relay_translator.from_relay(
@@ -143,17 +143,20 @@ def verify_extracted_tasks(target_str, layout, 
batch_size, image_shape):
             "relay.backend.use_meta_schedule": True,
             "relay.FuseOps.max_depth": 1,  # Disable relay fusion
         },
+        module_equality=module_equality,
     )
     relax_tasks = ms.relax_integration.extract_tasks(
         relax_mod,
         target=target,
         params=params,
+        module_equality=module_equality,
     )
     # TODO (yongwww, yuchen): tophub guides relay passes, which causes 
inconsistent tasks
     # assert len(relay_tasks) == len(relax_tasks)
     # TODO: Can we compare extracted tasks as well?
 
 
[email protected]("module_equality", ["structural", "ignore-ndarray", 
"anchor-block"])
 @pytest.mark.parametrize(
     "layout, batch_size, image_shape",
     [
@@ -161,16 +164,17 @@ def verify_extracted_tasks(target_str, layout, 
batch_size, image_shape):
         ("NHWC", 1, (224, 224, 3)),
     ],
 )
-def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape):
-    verify_extracted_tasks("llvm --num-cores=16", layout, batch_size, 
image_shape)
+def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape, 
module_equality):
+    verify_extracted_tasks("llvm --num-cores=16", layout, batch_size, 
image_shape, module_equality)
 
 
 @tvm.testing.requires_gpu
[email protected]("module_equality", ["structural", "ignore-ndarray", 
"anchor-block"])
 @pytest.mark.parametrize(
     "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 
1, (224, 224, 3))]
 )
-def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape):
-    verify_extracted_tasks("cuda", layout, batch_size, image_shape)
+def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape, 
module_equality):
+    verify_extracted_tasks("cuda", layout, batch_size, image_shape, 
module_equality)
 
 
 def translate_and_build_vms(relay_mod, target_str="llvm", 
translate_op_with_tir=None):

Reply via email to