This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0f05116453 [Unity][MetaSchedule] Add the module_equality param for
tune_relax flow (#14537)
0f05116453 is described below
commit 0f051164532afb053ab7c25e891523d3e96359f1
Author: Icemist <[email protected]>
AuthorDate: Mon Apr 17 13:07:22 2023 +0300
[Unity][MetaSchedule] Add the module_equality param for tune_relax flow
(#14537)
---
python/tvm/meta_schedule/relax_integration.py | 41 +++-
src/relax/backend/task_extraction.cc | 78 ++++++--
src/relax/transform/meta_schedule.cc | 21 ++-
.../relax/test_meta_schedule_relax_integration.py | 210 +++++++++++++++++++++
tests/python/relax/test_relay_translator.py | 16 +-
5 files changed, 337 insertions(+), 29 deletions(-)
diff --git a/python/tvm/meta_schedule/relax_integration.py
b/python/tvm/meta_schedule/relax_integration.py
index 62f5865242..c776d64763 100644
--- a/python/tvm/meta_schedule/relax_integration.py
+++ b/python/tvm/meta_schedule/relax_integration.py
@@ -57,6 +57,7 @@ def extract_tasks(
mod: Union[IRModule, "relax.Function"],
target: Target,
params: Optional[Dict[str, NDArray]] = None,
+ module_equality: str = "structural",
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relax program.
@@ -66,6 +67,18 @@ def extract_tasks(
The module or function to tune
target : tvm.target.Target
The compilation target
+ params : Optional[Dict[str, tvm.runtime.NDArray]]
+ The associated parameters of the program
+ module_equality : Optional[str]
+ A string to specify the module equality testing and hashing method.
+      It must be one of the following:
+ - "structural": Use StructuralEqual/Hash
+ - "ignore-ndarray": Same as "structural", but ignore ndarray raw
data during
+ equality testing and hashing.
+ - "anchor-block": Apply equality testing and hashing on the anchor
block extracted from a
+                          given module. The "ignore-ndarray" variant is used
for the extracted
+ blocks or in case no anchor block is found.
+ For the definition of the anchor block, see
tir/analysis/analysis.py.
Returns
-------
@@ -83,7 +96,7 @@ def extract_tasks(
target = Target(target)
if params:
mod = BindParams("main", params)(mod)
- return list(_extract_task_func(mod, target))
+ return list(_extract_task_func(mod, target, module_equality))
def extracted_tasks_to_tune_contexts(
@@ -162,6 +175,7 @@ def tune_relax(
space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
strategy: SearchStrategy.SearchStrategyType = "evolutionary",
seed: Optional[int] = None,
+ module_equality: str = "structural",
) -> Database:
"""Tune a Relax program.
@@ -199,6 +213,16 @@ def tune_relax(
The search strategy to use
seed : Optional[int]
The random seed
+ module_equality : Optional[str]
+ A string to specify the module equality testing and hashing method.
+      It must be one of the following:
+ - "structural": Use StructuralEqual/Hash
+ - "ignore-ndarray": Same as "structural", but ignore ndarray raw
data during
+ equality testing and hashing.
+ - "anchor-block": Apply equality testing and hashing on the anchor
block extracted from a
+                          given module. The "ignore-ndarray" variant is used
for the extracted
+ blocks or in case no anchor block is found.
+ For the definition of the anchor block, see
tir/analysis/analysis.py.
Returns
-------
@@ -206,7 +230,7 @@ def tune_relax(
The database that contains the tuning records
"""
tasks, task_weights = extracted_tasks_to_tune_contexts(
- extracted_tasks=extract_tasks(mod, target, params),
+ extracted_tasks=extract_tasks(mod, target, params,
module_equality=module_equality),
work_dir=work_dir,
space=space,
strategy=strategy,
@@ -225,6 +249,7 @@ def tune_relax(
cost_model=cost_model,
measure_callbacks=measure_callbacks,
task_scheduler=task_scheduler,
+ module_equality=module_equality,
)
@@ -247,6 +272,7 @@ def _tune_relax(
space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
strategy: SearchStrategy.SearchStrategyType = "evolutionary",
seed: Optional[int] = None,
+ module_equality: str = "structural",
) -> Database:
"""Interface with tuning api to tune a Relax program.
@@ -284,6 +310,16 @@ def _tune_relax(
The search strategy to use
seed : Optional[int]
The random seed
+ module_equality : Optional[str]
+ A string to specify the module equality testing and hashing method.
+      It must be one of the following:
+ - "structural": Use StructuralEqual/Hash
+ - "ignore-ndarray": Same as "structural", but ignore ndarray raw
data during
+ equality testing and hashing.
+ - "anchor-block": Apply equality testing and hashing on the anchor
block extracted from a
+                          given module. The "ignore-ndarray" variant is used
for the extracted
+ blocks or in case no anchor block is found.
+ For the definition of the anchor block, see
tir/analysis/analysis.py.
Returns
-------
@@ -310,6 +346,7 @@ def _tune_relax(
space=space,
strategy=strategy,
seed=seed,
+ module_equality=module_equality,
)
# Return original IRModule
# This pass only makes optimization decision
diff --git a/src/relax/backend/task_extraction.cc
b/src/relax/backend/task_extraction.cc
index 5bd764c68e..ebaee74f47 100644
--- a/src/relax/backend/task_extraction.cc
+++ b/src/relax/backend/task_extraction.cc
@@ -22,12 +22,18 @@
#include <tvm/relax/expr_functor.h>
#include <tvm/target/target.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/module_equality.h"
namespace tvm {
namespace relax {
namespace backend {
-using tvm::meta_schedule::ExtractedTask;
+using meta_schedule::ExtractedTask;
+using meta_schedule::ModuleEqual;
+using meta_schedule::ModuleEquality;
+using meta_schedule::ModuleHash;
/*!
* \brief Extract the Meta-Schedule tuning task from a given IRModule.
@@ -42,22 +48,45 @@ using tvm::meta_schedule::ExtractedTask;
* Then we will have a ExtractedTask for all three functions, whose weight
* is 5 + 3 + 2 = 10.
*/
+class BlockCounter : public tir::StmtVisitor {
+ public:
+ static size_t GetBlockCount(const tir::PrimFunc& func) {
+ BlockCounter counter;
+ counter(func->body);
+ return counter.count;
+ }
+
+ private:
+ void VisitStmt_(const tir::BlockNode* op) final {
+ ++count;
+ StmtVisitor::VisitStmt_(op);
+ }
+ size_t count{0};
+};
+
class TaskExtractor : public ExprVisitor {
public:
- static Array<ExtractedTask> ExtractTask(IRModule mod, Target target) {
- TaskExtractor extractor(mod, target);
+ static Array<ExtractedTask> ExtractTask(IRModule mod, Target target, String
mod_eq_name) {
+ TaskExtractor extractor(mod, target, mod_eq_name);
// We go through each Relax function in the module.
for (const auto& kv : mod->functions) {
if (const auto* func = kv.second.as<FunctionNode>()) {
extractor(GetRef<Function>(func));
}
}
- return std::move(extractor.tasks_);
+ Array<ExtractedTask> tasks;
+ for (const auto& it : extractor.func2task_) {
+ tasks.push_back(it.second);
+ }
+ return tasks;
}
private:
- explicit TaskExtractor(IRModule mod, Target target)
- : mod_(std::move(mod)), target_(std::move(target)) {
+ explicit TaskExtractor(IRModule mod, Target target, String mod_eq_name)
+ : mod_(std::move(mod)),
+ target_(std::move(target)),
+ mod_eq_(ModuleEquality::Create(mod_eq_name)),
+ func2task_(/*bucket_count*/ 0, ModuleHash(*mod_eq_),
ModuleEqual(*mod_eq_)) {
normalize_mod_func_ =
runtime::Registry::Get("tvm.meta_schedule.normalize_mod");
ICHECK(normalize_mod_func_) << "Normalization function is not found.";
}
@@ -75,33 +104,44 @@ class TaskExtractor : public ExprVisitor {
const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
const tir::PrimFunc& func =
Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
-
- auto it = func2task_.find(func);
+ IRModule mod = (*normalize_mod_func_)(func);
+ size_t weight = 1;
+ auto it = func2task_.find(mod);
if (it != func2task_.end()) {
it->second->weight += 1;
- return;
+ const tir::PrimFunc& alt_func =
Downcast<tir::PrimFunc>(it->first->Lookup("main"));
+ // When anchor-block based equality is used, tuning tasks
"nn_conv2d_add_nn_relu" and
+ // "nn_conv2d_add_add_nn_relu", for example, can be identified as equal.
Thus, one of them
+ // will be selected to tune by the code below.
+ //
+ // To make sure that we tune "nn_conv2d_add_nn_relu" and not
"nn_conv2d_add_add_nn_relu", we
+    // count the PrimFunc number of blocks and leave only the function with
the smallest number of
+ // blocks. This way, "nn_conv2d_add_nn_relu" will have a smaller number
of blocks than
+ // "nn_conv2d_add_add_nn_relu" and will be selected to tune.
+ if (BlockCounter::GetBlockCount(func) <
BlockCounter::GetBlockCount(alt_func)) {
+ weight += it->second->weight;
+ func2task_.erase(it->first);
+ }
}
- IRModule tir_mod = (*normalize_mod_func_)(func);
ExtractedTask task(/*task_name=*/global_var->name_hint, //
- /*mod=*/tir_mod, //
+ /*mod=*/mod, //
/*target=*/target_, //
- /*dispatched=*/{tir_mod}, //
- /*weight=*/1);
- tasks_.push_back(task);
- func2task_.emplace(func, task);
+ /*dispatched=*/{mod}, //
+ /*weight=*/weight);
+ func2task_.emplace(mod, task);
}
IRModule mod_;
Target target_;
- Array<ExtractedTask> tasks_;
- std::unordered_map<tir::PrimFunc, ExtractedTask, StructuralHash,
StructuralEqual> func2task_;
+ std::unique_ptr<ModuleEquality> mod_eq_;
+ std::unordered_map<IRModule, ExtractedTask, ModuleHash, ModuleEqual>
func2task_;
const runtime::PackedFunc* normalize_mod_func_;
};
TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask")
- .set_body_typed([](IRModule mod, Target target) {
- return TaskExtractor::ExtractTask(std::move(mod), std::move(target));
+ .set_body_typed([](IRModule mod, Target target, String mod_eq_name) {
+ return TaskExtractor::ExtractTask(std::move(mod), std::move(target),
std::move(mod_eq_name));
});
} // namespace backend
diff --git a/src/relax/transform/meta_schedule.cc
b/src/relax/transform/meta_schedule.cc
index e205e984df..c33a90ccb1 100644
--- a/src/relax/transform/meta_schedule.cc
+++ b/src/relax/transform/meta_schedule.cc
@@ -26,6 +26,9 @@
#include <tvm/relax/tuning_api.h>
#include <tvm/tir/transform.h>
+#include "../src/meta_schedule/module_equality.h"
+#include "../src/meta_schedule/trace_apply.h"
+
namespace tvm {
namespace relax {
namespace transform {
@@ -105,6 +108,7 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
}
Map<GlobalVar, BaseFunc> result;
+ auto mod_eq_structural =
meta_schedule::ModuleEquality::Create("ignore-ndarray");
for (const auto& iter : mod->functions) {
GlobalVar gv = iter.first;
BaseFunc base_func = iter.second;
@@ -112,8 +116,21 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
IRModule tir_mod = (*normalize_mod_func_)(prim_func);
- if (Optional<tir::Schedule> sch = database->QuerySchedule(tir_mod,
target, gv->name_hint)) {
- IRModule new_mod = sch.value()->mod();
+ if (Optional<meta_schedule::TuningRecord> opt_record =
+ database->QueryTuningRecord(tir_mod, target, gv->name_hint)) {
+ meta_schedule::TuningRecord record = opt_record.value();
+ tir::Schedule sch =
+ tir::Schedule::Traced(tir_mod, /*seed=*/-1, /*debug_mask=*/0,
+
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
+ if (!mod_eq_structural->Equal(tir_mod, record->workload->mod)) {
+ // When the database lookup succeeds while structural equality
check fails,
+ // it implies that the anchor block based equality has been used
during tuning.
+ // The trace in the record cannot directly be applied to this
query module.
+ meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace,
target);
+ } else {
+ record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
+ }
+ IRModule new_mod = sch->mod();
ICHECK_EQ(new_mod->functions.size(), 1);
BaseFunc new_base_func = (*new_mod->functions.begin()).second;
ICHECK(new_base_func->IsInstance<tir::PrimFuncNode>());
diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py
b/tests/python/relax/test_meta_schedule_relax_integration.py
new file mode 100644
index 0000000000..a66a29405b
--- /dev/null
+++ b/tests/python/relax/test_meta_schedule_relax_integration.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Integration test for MetaSchedule"""
+
+import numpy as np
+import pytest
+import tempfile
+import tvm
+import tvm.testing
+from tvm import IRModule
+from tvm import meta_schedule as ms
+from tvm import relax, tir
+from tvm.ir import transform
+
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.script import relax as R
+
+# fmt: off
[email protected]_module
+class Module0:
+ @R.function
+ def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8,
8, 4), dtype="int32"):
+ cls = Module0
+ with R.dataflow():
+ c =
R.const([[[[-171701247],[-1719837685],[1801664104],[-634316588]],[[920159370],[-132073802],[2142531563],[1465185701]],[[-1505608067],[1737948828],[1581089391],[-1986167320]]],[[[-1449581822],[35714587],[496324563],[-1430879015]],[[-1615680873],[1198514997],[1494683955],[1567376558]],[[1319924884],[-380548171],[296785437],[-1546305981]]],[[[-398644701],[-2004794585],[-1850413687],[2072643657]],[[847950121],[-544212073],[-199532669],[-343273682]],[[953721562],[-1930209358],
[...]
+ lv: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(data, c,
strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4,
data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
+ b = R.const([[[[1, 1, 1, 1]]]], "int32")
+ lv1: R.Tensor((1, 8, 8, 4), dtype="int32") = R.add(lv, b)
+ c1 =
R.const([[[[2042349344],[-2076067063],[1528163722],[-1156452837]],[[-2097172051],[1137787079],[-601389657],[1907495997]],[[987801941],[1073738593],[-1410339796],[-689755358]]],[[[90351522],[-44886952],[-1914103775],[-691553659]],[[-1288505112],[-1376578817],[-2067933148],[-1413101824]],[[1261422027],[-156976862],[-1185734459],[1608778622]]],[[[-664209483],[1907479806],[1838595152],[464942526]],[[877953160],[415131837],[-2010736511],[1218242769]],[[-1440127632],[112931],[
[...]
+ lv2: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(lv1, c1,
strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4,
data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
+ c2 =
R.const([[[[687940110],[-910571705],[-901609800],[-500525928]],[[506872399],[1070176297],[-305936110],[1625439784]],[[-1565626954],[-1705688881],[-866370805],[-1750740826]]],[[[300497007],[-626864803],[390295545],[222549121]],[[319224543],[-2003064970],[657992492],[2014175448]],[[653278589],[-768810984],[-294555581],[-1197167662]]],[[[1703154671],[-1540759805],[-568817430],[-1729755444]],[[-275458074],[2078945571],[1683298006],[-1029327874]],[[1315093181],[159010501],[87
[...]
+ lv3: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(lv2, c2,
strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4,
data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
+ gv: R.Tensor((1, 8, 8, 4), dtype="int32") = lv3
+ R.output(gv)
+ return gv
+
+# fmt: on
+
+# fmt: off
[email protected]_module
+class Module:
+ @T.prim_func
+ def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8),
T.int64(4)), "int32"), DepthwiseConv2d: T.Buffer((T.int64(1), T.int64(8),
T.int64(8), T.int64(4)), "int32")):
+ T.func_attr({"op_pattern": 4, "tir.noalias": True})
+ # with T.block("root"):
+ PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10),
T.int64(4)), "int32")
+ fused_constant = T.allocate_const([-171701247, -1719837685,
1801664104, -634316588, 920159370, -132073802, 2142531563, 1465185701,
-1505608067, 1737948828, 1581089391, -1986167320, -1449581822, 35714587,
496324563, -1430879015, -1615680873, 1198514997, 1494683955, 1567376558,
1319924884, -380548171, 296785437, -1546305981, -398644701, -2004794585,
-1850413687, 2072643657, 847950121, -544212073, -199532669, -343273682,
953721562, -1930209358, 1573600108, -577689853], "int32", [3, [...]
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10),
T.int64(4)):
+ with T.block("PaddedInput"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 -
T.int64(1), v_i3])
+ T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3])
+ PaddedInput[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2
and v_i2 < T.int64(9), rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 -
T.int64(1), v_i3], 0)
+ for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8),
T.int64(4), T.int64(3), T.int64(3)):
+ with T.block("DepthwiseConv2d"):
+ v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i,
j, c, di, dj])
+ fused_constant_1 = T.Buffer((3, 3, 4, 1), "int32",
data=fused_constant)
+ T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c],
fused_constant_1[v_di, v_dj, v_c, T.int64(0)])
+ T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c])
+ with T.init():
+ DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
+ DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b,
v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] *
fused_constant_1[v_di, v_dj, v_c, T.int64(0)]
+
+ @T.prim_func
+ def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8),
T.int64(4)), "int32"), DepthwiseConv2d0: T.Buffer((T.int64(1), T.int64(8),
T.int64(8), T.int64(4)), "int32")):
+ T.func_attr({"op_pattern": 4, "tir.noalias": True})
+ # with T.block("root"):
+ PaddedInput0 = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10),
T.int64(4)), "int32")
+ fused_constant0 = T.allocate_const([2042349344, -2076067063,
1528163722, -1156452837, -2097172051, 1137787079, -601389657, 1907495997,
987801941, 1073738593, -1410339796, -689755358, 90351522, -44886952,
-1914103775, -691553659, -1288505112, -1376578817, -2067933148, -1413101824,
1261422027, -156976862, -1185734459, 1608778622, -664209483, 1907479806,
1838595152, 464942526, 877953160, 415131837, -2010736511, 1218242769,
-1440127632, 112931, 521745784, -1931145893], "int32", [3, 3 [...]
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10),
T.int64(4)):
+ with T.block("PaddedInput"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 -
T.int64(1), v_i3])
+ T.writes(PaddedInput0[v_i0, v_i1, v_i2, v_i3])
+ PaddedInput0[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2
and v_i2 < T.int64(9), rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 -
T.int64(1), v_i3], 0)
+ for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8),
T.int64(4), T.int64(3), T.int64(3)):
+ with T.block("DepthwiseConv2d"):
+ v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i,
j, c, di, dj])
+ fused_constant0_1 = T.Buffer((3, 3, 4, 1), "int32",
data=fused_constant0)
+ T.reads(PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c],
fused_constant0_1[v_di, v_dj, v_c, T.int64(0)])
+ T.writes(DepthwiseConv2d0[v_b, v_i, v_j, v_c])
+ with T.init():
+ DepthwiseConv2d0[v_b, v_i, v_j, v_c] = 0
+ DepthwiseConv2d0[v_b, v_i, v_j, v_c] = DepthwiseConv2d0[v_b,
v_i, v_j, v_c] + PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c] *
fused_constant0_1[v_di, v_dj, v_c, T.int64(0)]
+
+ @T.prim_func
+ def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8),
T.int64(4)), "int32"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(8),
T.int64(4)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10),
T.int64(4)), "int32")
+ DepthwiseConv2d = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(8),
T.int64(4)), "int32")
+ fused_nn_conv2d_constant = T.allocate_const([1, 1, 1, 1], "int32", [1,
1, 1, 4])
+ fused_constant_2 = T.allocate_const([687940110, -910571705,
-901609800, -500525928, 506872399, 1070176297, -305936110, 1625439784,
-1565626954, -1705688881, -866370805, -1750740826, 300497007, -626864803,
390295545, 222549121, 319224543, -2003064970, 657992492, 2014175448, 653278589,
-768810984, -294555581, -1197167662, 1703154671, -1540759805, -568817430,
-1729755444, -275458074, 2078945571, 1683298006, -1029327874, 1315093181,
159010501, 875694807, -223655381], "int32", [3, 3, 4, 1])
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10),
T.int64(4)):
+ with T.block("PaddedInput"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
+ T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3])
+ PaddedInput[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2
and v_i2 < T.int64(9), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3],
0)
+ for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8),
T.int64(4), T.int64(3), T.int64(3)):
+ with T.block("DepthwiseConv2d"):
+ v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i,
j, c, di, dj])
+ fused_constant_2_1 = T.Buffer((3, 3, 4, 1), "int32",
data=fused_constant_2)
+ T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c],
fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)])
+ T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c])
+ with T.init():
+ DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
+ DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b,
v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] *
fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(8), T.int64(8),
T.int64(4)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1,
ax2, ax3])
+ fused_nn_conv2d_constant_1 = T.Buffer((1, 1, 1, 4), "int32",
data=fused_nn_conv2d_constant)
+ T.reads(DepthwiseConv2d[v_ax0, v_ax1, v_ax2, v_ax3],
fused_nn_conv2d_constant_1[v_ax0, T.int64(0), T.int64(0), v_ax3])
+ T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] = DepthwiseConv2d[v_ax0,
v_ax1, v_ax2, v_ax3] + fused_nn_conv2d_constant_1[v_ax0, T.int64(0),
T.int64(0), v_ax3]
+
+ @R.function
+ def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8,
8, 4), dtype="int32"):
+ cls = Module
+ with R.dataflow():
+ lv = R.call_tir(cls.fused_conv2d_add, data, out_sinfo=R.Tensor((1,
8, 8, 4), dtype="int32"))
+ lv2 = R.call_tir(cls.conv2d, lv, out_sinfo=R.Tensor((1, 8, 8, 4),
dtype="int32"))
+ lv3 = R.call_tir(cls.conv2d0, lv2, out_sinfo=R.Tensor((1, 8, 8,
4), dtype="int32"))
+ gv: R.Tensor((1, 8, 8, 4), dtype="int32") = lv3
+ R.output(gv)
+ return gv
+# fmt: on
+
+
+def test_extracting_tasks():
+ target = "llvm -mcpu=core-avx2 -num-cores=1"
+
+ relax_mod = Module0
+ relax_mod = relax.transform.LegalizeOps()(relax_mod)
+ relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod)
+ relax_mod = relax.transform.FuseOps()(relax_mod)
+ relax_mod = relax.transform.FoldConstant()(relax_mod)
+ relax_mod = relax.transform.FuseTIR()(relax_mod)
+
+ relax_expectation = {
+ "structural": 2, # The relax constants do not reach the tir at the
lowering.
+ "ignore-ndarray": 2,
+ "anchor-block": 1,
+ }
+ for module_equality, count in relax_expectation.items():
+ extracted_tasks = ms.relax_integration.extract_tasks(
+ relax_mod,
+ target,
+ {},
+ module_equality=module_equality,
+ )
+ assert len(extracted_tasks) == count
+
+ tir_relax_mod = Module
+ tir_relax_expectation = {"structural": 3, "ignore-ndarray": 2,
"anchor-block": 1}
+ for module_equality, count in tir_relax_expectation.items():
+ extracted_tasks = ms.relax_integration.extract_tasks(
+ tir_relax_mod,
+ target,
+ {},
+ module_equality=module_equality,
+ )
+ assert len(extracted_tasks) == count
+
+
[email protected]("module_equality", ["structural", "ignore-ndarray",
"anchor-block"])
+def test_using_anchor_trace(module_equality):
+ relax_mod = Module
+ target = "llvm -mcpu=core-avx2 -num-cores=1"
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ database = ms.relax_integration.tune_relax(
+ mod=relax_mod,
+ params={},
+ target=target,
+ work_dir=work_dir,
+ # for faster tuning
+ max_trials_global=100,
+ max_trials_per_task=4,
+ num_trials_per_iter=4,
+ strategy="replay-trace",
+ module_equality=module_equality,
+ seed=0,
+ )
+
+ ms.relax_integration.compile_relax(
+ database,
+ mod=relax_mod,
+ target=target,
+ params={},
+ )
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_relay_translator.py
b/tests/python/relax/test_relay_translator.py
index d3cd47b9e6..54cd1b243d 100644
--- a/tests/python/relax/test_relay_translator.py
+++ b/tests/python/relax/test_relay_translator.py
@@ -74,6 +74,7 @@ def relax_build_and_run(mod, target, dev, params, data):
db = ms.relax_integration.tune_relax(
mod=mod,
target=target,
+ params=params,
task_scheduler="round-robin",
num_trials_per_iter=32,
max_trials_per_task=32,
@@ -99,7 +100,6 @@ def verify_e2e_translation(target_str, layout, batch_size,
image_shape):
input_shape = (1, *image_shape)
data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev)
relax_mod = relay_translator.from_relay(relay_mod["main"], target, params)
- assert relax_mod["main"].attrs["global_symbol"] == "main"
_, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data)
_, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data)
@@ -123,7 +123,7 @@ def test_verify_e2e_translation_gpu(layout, batch_size,
image_shape):
verify_e2e_translation("cuda", layout, batch_size, image_shape)
-def verify_extracted_tasks(target_str, layout, batch_size, image_shape):
+def verify_extracted_tasks(target_str, layout, batch_size, image_shape,
module_equality):
target = Target(target_str)
relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape)
relax_mod = relay_translator.from_relay(
@@ -143,17 +143,20 @@ def verify_extracted_tasks(target_str, layout,
batch_size, image_shape):
"relay.backend.use_meta_schedule": True,
"relay.FuseOps.max_depth": 1, # Disable relay fusion
},
+ module_equality=module_equality,
)
relax_tasks = ms.relax_integration.extract_tasks(
relax_mod,
target=target,
params=params,
+ module_equality=module_equality,
)
# TODO (yongwww, yuchen): tophub guides relay passes, which causes
inconsistent tasks
# assert len(relay_tasks) == len(relax_tasks)
# TODO: Can we compare extracted tasks as well?
[email protected]("module_equality", ["structural", "ignore-ndarray",
"anchor-block"])
@pytest.mark.parametrize(
"layout, batch_size, image_shape",
[
@@ -161,16 +164,17 @@ def verify_extracted_tasks(target_str, layout,
batch_size, image_shape):
("NHWC", 1, (224, 224, 3)),
],
)
-def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape):
- verify_extracted_tasks("llvm --num-cores=16", layout, batch_size,
image_shape)
+def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape,
module_equality):
+ verify_extracted_tasks("llvm --num-cores=16", layout, batch_size,
image_shape, module_equality)
@tvm.testing.requires_gpu
[email protected]("module_equality", ["structural", "ignore-ndarray",
"anchor-block"])
@pytest.mark.parametrize(
"layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC",
1, (224, 224, 3))]
)
-def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape):
- verify_extracted_tasks("cuda", layout, batch_size, image_shape)
+def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape,
module_equality):
+ verify_extracted_tasks("cuda", layout, batch_size, image_shape,
module_equality)
def translate_and_build_vms(relay_mod, target_str="llvm",
translate_op_with_tir=None):