cmpute commented on issue #18481:
URL: https://github.com/apache/tvm/issues/18481#issuecomment-3596203740
@ConvolutedDog Thanks for pointing this out. I did indeed run the script with
only 100 trials. However, even after increasing it to 1200 trials so that every
task gets executed, I still hit a similar error; the full log is collapsed below.
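For reference, the tuning flow corresponds to the standard `e2e_opt_model.py`
tutorial pipeline. Below is a minimal sketch under that assumption; only the
trial count (1200) and the `tuning_logs` work directory are taken from the log
above, the rest is the stock tutorial sequence rather than an exact copy of my script.
```python
import tvm
from tvm import relax

# Sketch assuming the stock e2e_opt_model.py tutorial flow; `mod` is the
# Relax module exported from the torchvision ResNet18 model.
TOTAL_TRIALS = 1200
target = tvm.target.Target("cuda")  # sm_87 Jetson target on the actual device
work_dir = "tuning_logs"

with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.transform.DecomposeOpsForInference(),  # split BatchNorm for fusion
            relax.transform.CanonicalizeBindings(),
            relax.get_pipeline("zero"),                  # default Relax optimization pipeline
            relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS),
            relax.transform.MetaScheduleApplyDatabase(work_dir),
        ]
    )(mod)

ex = tvm.compile(mod, target="cuda")  # this is the line that fails below
```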
<details>
```
 18 | fused_conv2d7_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8        |  13045760 | 1 | 190.3835 |  68.5236 |  68.5236 | 64 | Y
 19 | fused_conv2d8_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4 | 115730944 | 1 | 253.3194 | 456.8578 | 456.8578 | 64 | Y
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 1216
Total latency (us): 10629.9
[19:56:36]
/opt/mlc-llm/3rdparty/tvm/src/relax/transform/meta_schedule.cc:91: Warning:
Creating JSONDatabase. Workload at: tuning_logs/database_workload.json, Tuning
records at: tuning_logs/database_tuning_record.json
# from tvm.script import relax as R
@R.function
def main(x: R.Tensor((1, 3, 224, 224), dtype="float32"), p_conv1_weight:
R.Tensor((64, 3, 7, 7), dtype="float32"), p_bn1_weight: R.Tensor((64,),
dtype="float32"), p_bn1_bias: R.Tensor((64,), dtype="float32"),
p_layer1_0_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
p_layer1_0_bn1_weight: R.Tensor((64,), dtype="float32"), p_layer1_0_bn1_bias:
R.Tensor((64,), dtype="float32"), p_layer1_0_conv2_weight: R.Tensor((64, 64, 3,
3), dtype="float32"), p_layer1_0_bn2_weight: R.Tensor((64,), dtype="float32"),
p_layer1_0_bn2_bias: R.Tensor((64,), dtype="float32"), p_layer1_1_conv1_weight:
R.Tensor((64, 64, 3, 3), dtype="float32"), p_layer1_1_bn1_weight:
R.Tensor((64,), dtype="float32"), p_layer1_1_bn1_bias: R.Tensor((64,),
dtype="float32"), p_layer1_1_conv2_weight: R.Tensor((64, 64, 3, 3),
dtype="float32"), p_layer1_1_bn2_weight: R.Tensor((64,), dtype="float32"),
p_layer1_1_bn2_bias: R.Tensor((64,), dtype="float32"), p_layer2_0_conv1_weight:
R.Tensor((128, 64, 3, 3), dtype="float32"
), p_layer2_0_bn1_weight: R.Tensor((128,), dtype="float32"),
p_layer2_0_bn1_bias: R.Tensor((128,), dtype="float32"),
p_layer2_0_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
p_layer2_0_bn2_weight: R.Tensor((128,), dtype="float32"), p_layer2_0_bn2_bias:
R.Tensor((128,), dtype="float32"), p_layer2_0_downsample_0_weight:
R.Tensor((128, 64, 1, 1), dtype="float32"), p_layer2_0_downsample_1_weight:
R.Tensor((128,), dtype="float32"), p_layer2_0_downsample_1_bias:
R.Tensor((128,), dtype="float32"), p_layer2_1_conv1_weight: R.Tensor((128, 128,
3, 3), dtype="float32"), p_layer2_1_bn1_weight: R.Tensor((128,),
dtype="float32"), p_layer2_1_bn1_bias: R.Tensor((128,), dtype="float32"),
p_layer2_1_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
p_layer2_1_bn2_weight: R.Tensor((128,), dtype="float32"), p_layer2_1_bn2_bias:
R.Tensor((128,), dtype="float32"), p_layer3_0_conv1_weight: R.Tensor((256, 128,
3, 3), dtype="float32"), p_layer3_0_bn1_weight: R.Tensor((256,), dtype="flo
at32"), p_layer3_0_bn1_bias: R.Tensor((256,), dtype="float32"),
p_layer3_0_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
p_layer3_0_bn2_weight: R.Tensor((256,), dtype="float32"), p_layer3_0_bn2_bias:
R.Tensor((256,), dtype="float32"), p_layer3_0_downsample_0_weight:
R.Tensor((256, 128, 1, 1), dtype="float32"), p_layer3_0_downsample_1_weight:
R.Tensor((256,), dtype="float32"), p_layer3_0_downsample_1_bias:
R.Tensor((256,), dtype="float32"), p_layer3_1_conv1_weight: R.Tensor((256, 256,
3, 3), dtype="float32"), p_layer3_1_bn1_weight: R.Tensor((256,),
dtype="float32"), p_layer3_1_bn1_bias: R.Tensor((256,), dtype="float32"),
p_layer3_1_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
p_layer3_1_bn2_weight: R.Tensor((256,), dtype="float32"), p_layer3_1_bn2_bias:
R.Tensor((256,), dtype="float32"), p_layer4_0_conv1_weight: R.Tensor((512, 256,
3, 3), dtype="float32"), p_layer4_0_bn1_weight: R.Tensor((512,),
dtype="float32"), p_layer4_0_bn1_bias: R.Tensor((512,), dtype=
"float32"), p_layer4_0_conv2_weight: R.Tensor((512, 512, 3, 3),
dtype="float32"), p_layer4_0_bn2_weight: R.Tensor((512,), dtype="float32"),
p_layer4_0_bn2_bias: R.Tensor((512,), dtype="float32"),
p_layer4_0_downsample_0_weight: R.Tensor((512, 256, 1, 1), dtype="float32"),
p_layer4_0_downsample_1_weight: R.Tensor((512,), dtype="float32"),
p_layer4_0_downsample_1_bias: R.Tensor((512,), dtype="float32"),
p_layer4_1_conv1_weight: R.Tensor((512, 512, 3, 3), dtype="float32"),
p_layer4_1_bn1_weight: R.Tensor((512,), dtype="float32"), p_layer4_1_bn1_bias:
R.Tensor((512,), dtype="float32"), p_layer4_1_conv2_weight: R.Tensor((512, 512,
3, 3), dtype="float32"), p_layer4_1_bn2_weight: R.Tensor((512,),
dtype="float32"), p_layer4_1_bn2_bias: R.Tensor((512,), dtype="float32"),
p_fc_weight: R.Tensor((1000, 512), dtype="float32"), p_fc_bias:
R.Tensor((1000,), dtype="float32")) -> R.Tuple(R.Tensor((1, 1000),
dtype="float32")):
R.func_attr({"num_input": 1})
with R.dataflow():
lv =
R.call_tir(fused_conv2d_subtract_divide_expand_dims_multiply_expand_dims_add1_relu,
(x, p_conv1_weight, metadata["relax.expr.Constant"][0],
metadata["relax.expr.Constant"][1], p_bn1_weight, p_bn1_bias),
out_sinfo=R.Tensor((1, 64, 112, 112), dtype="float32"))
lv4 = R.call_tir(max_pool2d, (lv,), out_sinfo=R.Tensor((1, 64, 56,
56), dtype="float32"))
lv1 =
R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1,
(lv4, p_layer1_0_conv1_weight, metadata["relax.expr.Constant"][2],
metadata["relax.expr.Constant"][3], p_layer1_0_bn1_weight,
p_layer1_0_bn1_bias), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
lv2 =
R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1,
(lv1, p_layer1_0_conv2_weight, metadata["relax.expr.Constant"][4],
metadata["relax.expr.Constant"][5], p_layer1_0_bn2_weight, p_layer1_0_bn2_bias,
lv4), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
lv3 =
R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1,
(lv2, p_layer1_1_conv1_weight, metadata["relax.expr.Constant"][6],
metadata["relax.expr.Constant"][7], p_layer1_1_bn1_weight,
p_layer1_1_bn1_bias), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
lv4_1 =
R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1,
(lv3, p_layer1_1_conv2_weight, metadata["relax.expr.Constant"][8],
metadata["relax.expr.Constant"][9], p_layer1_1_bn2_weight, p_layer1_1_bn2_bias,
lv2), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
lv5 =
R.call_tir(fused_conv2d2_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2,
(lv4_1, p_layer2_0_conv1_weight, metadata["relax.expr.Constant"][10],
metadata["relax.expr.Constant"][11], p_layer2_0_bn1_weight,
p_layer2_0_bn1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
lv6 =
R.call_tir(fused_conv2d4_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5,
(lv4_1, p_layer2_0_downsample_0_weight, metadata["relax.expr.Constant"][12],
metadata["relax.expr.Constant"][13], p_layer2_0_downsample_1_weight,
p_layer2_0_downsample_1_bias), out_sinfo=R.Tensor((1, 128, 28, 28),
dtype="float32"))
lv7 =
R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2,
(lv5, p_layer2_0_conv2_weight, metadata["relax.expr.Constant"][14],
metadata["relax.expr.Constant"][15], p_layer2_0_bn2_weight,
p_layer2_0_bn2_bias, lv6), out_sinfo=R.Tensor((1, 128, 28, 28),
dtype="float32"))
lv8 =
R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2,
(lv7, p_layer2_1_conv1_weight, metadata["relax.expr.Constant"][16],
metadata["relax.expr.Constant"][17], p_layer2_1_bn1_weight,
p_layer2_1_bn1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
lv9 =
R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2,
(lv8, p_layer2_1_conv2_weight, metadata["relax.expr.Constant"][18],
metadata["relax.expr.Constant"][19], p_layer2_1_bn2_weight,
p_layer2_1_bn2_bias, lv7), out_sinfo=R.Tensor((1, 128, 28, 28),
dtype="float32"))
lv10 =
R.call_tir(fused_conv2d5_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3,
(lv9, p_layer3_0_conv1_weight, metadata["relax.expr.Constant"][20],
metadata["relax.expr.Constant"][21], p_layer3_0_bn1_weight,
p_layer3_0_bn1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
lv11 =
R.call_tir(fused_conv2d7_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8,
(lv9, p_layer3_0_downsample_0_weight, metadata["relax.expr.Constant"][22],
metadata["relax.expr.Constant"][23], p_layer3_0_downsample_1_weight,
p_layer3_0_downsample_1_bias), out_sinfo=R.Tensor((1, 256, 14, 14),
dtype="float32"))
lv12 =
R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3,
(lv10, p_layer3_0_conv2_weight, metadata["relax.expr.Constant"][24],
metadata["relax.expr.Constant"][25], p_layer3_0_bn2_weight,
p_layer3_0_bn2_bias, lv11), out_sinfo=R.Tensor((1, 256, 14, 14),
dtype="float32"))
lv13 =
R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3,
(lv12, p_layer3_1_conv1_weight, metadata["relax.expr.Constant"][26],
metadata["relax.expr.Constant"][27], p_layer3_1_bn1_weight,
p_layer3_1_bn1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
lv14 =
R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3,
(lv13, p_layer3_1_conv2_weight, metadata["relax.expr.Constant"][28],
metadata["relax.expr.Constant"][29], p_layer3_1_bn2_weight,
p_layer3_1_bn2_bias, lv12), out_sinfo=R.Tensor((1, 256, 14, 14),
dtype="float32"))
lv15 =
R.call_tir(fused_conv2d8_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4,
(lv14, p_layer4_0_conv1_weight, metadata["relax.expr.Constant"][30],
metadata["relax.expr.Constant"][31], p_layer4_0_bn1_weight,
p_layer4_0_bn1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
lv16 =
R.call_tir(fused_conv2d10_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11,
(lv14, p_layer4_0_downsample_0_weight, metadata["relax.expr.Constant"][32],
metadata["relax.expr.Constant"][33], p_layer4_0_downsample_1_weight,
p_layer4_0_downsample_1_bias), out_sinfo=R.Tensor((1, 512, 7, 7),
dtype="float32"))
lv17 =
R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4,
(lv15, p_layer4_0_conv2_weight, metadata["relax.expr.Constant"][34],
metadata["relax.expr.Constant"][35], p_layer4_0_bn2_weight,
p_layer4_0_bn2_bias, lv16), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
lv18 =
R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4,
(lv17, p_layer4_1_conv1_weight, metadata["relax.expr.Constant"][36],
metadata["relax.expr.Constant"][37], p_layer4_1_bn1_weight,
p_layer4_1_bn1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
lv19 =
R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4,
(lv18, p_layer4_1_conv2_weight, metadata["relax.expr.Constant"][38],
metadata["relax.expr.Constant"][39], p_layer4_1_bn2_weight,
p_layer4_1_bn2_bias, lv17), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
lv86 = R.call_tir(adaptive_avg_pool2d, (lv19,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"))
lv87 = R.call_tir(reshape, (lv86,), out_sinfo=R.Tensor((1, 512),
dtype="float32"))
lv88 = R.call_tir(transpose, (p_fc_weight,),
out_sinfo=R.Tensor((512, 1000), dtype="float32"))
lv20 = R.call_tir(fused_matmul_add13, (lv87, lv88, p_fc_bias),
out_sinfo=R.Tensor((1, 1000), dtype="float32"))
gv: R.Tuple(R.Tensor((1, 1000), dtype="float32")) = (lv20,)
R.output(gv)
return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
Traceback (most recent call last):
File "/workspace/e2e_opt_model.py", line 117, in <module>
ex = tvm.compile(mod, target="cuda")
File "/opt/mlc-llm/3rdparty/tvm/python/tvm/driver/build_module.py", line
104, in compile
return tvm.relax.build(
File "/opt/mlc-llm/3rdparty/tvm/python/tvm/relax/vm_build.py", line 263,
in build
return _vmlink(
File "/opt/mlc-llm/3rdparty/tvm/python/tvm/relax/vm_build.py", line 158,
in _vmlink
lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
File "/opt/mlc-llm/3rdparty/tvm/python/tvm/tir/build.py", line 226, in
build
mod = pipeline(mod)
File "/opt/mlc-llm/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in
__call__
return _ffi_transform_api.RunPass(self, mod)
File "python/tvm_ffi/cython/function.pxi", line 678, in
core.Function.__call__
File "<unknown>", line 0, in
tvm::transform::Pass::operator()(tvm::IRModule) const
File "<unknown>", line 0, in
tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext
const&) const
File "<unknown>", line 0, in
tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
File "<unknown>", line 0, in std::_Function_handler<tvm::IRModule
(tvm::IRModule, tvm::transform::PassContext),
tvm::transform::__TVMFFIStaticInitFunc4()::{lambda(tvm::ffi::TypedFunction<tvm::IRModule
(tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>,
tvm::transform::PassInfo)#1}::operator()(tvm::ffi::TypedFunction<tvm::IRModule
(tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>,
tvm::transform::PassInfo) const::{lambda(tvm::IRModule,
tvm::transform::PassContext)#1}>::_M_invoke(std::_Any_data const&,
tvm::IRModule&&, tvm::transform::PassContext&&)
File "<unknown>", line 0, in
tvm::transform::__TVMFFIStaticInitFunc4()::{lambda(tvm::ffi::TypedFunction<tvm::IRModule
(tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>,
tvm::transform::PassInfo)#1}::operator()(tvm::ffi::TypedFunction<tvm::IRModule
(tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>,
tvm::transform::PassInfo) const::{lambda(tvm::IRModule,
tvm::transform::PassContext)#1}::operator()(tvm::IRModule,
tvm::transform::PassContext) const
File "python/tvm_ffi/cython/function.pxi", line 732, in
core.tvm_ffi_callback
File "/opt/mlc-llm/3rdparty/tvm/python/tvm/tir/pipeline.py", line 123, in
_pipeline
mod = tvm.ir.transform.Sequential(passes)(mod)
File "/opt/mlc-llm/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in
__call__
return _ffi_transform_api.RunPass(self, mod)
File "python/tvm_ffi/cython/function.pxi", line 678, in
core.Function.__call__
File "<unknown>", line 0, in
tvm::transform::Pass::operator()(tvm::IRModule) const
File "<unknown>", line 0, in
tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext
const&) const
File "<unknown>", line 0, in
tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
File "<unknown>", line 0, in
tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext
const&) const
File "<unknown>", line 0, in
tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
File "<unknown>", line 0, in std::_Function_handler<tvm::IRModule
(tvm::IRModule, tvm::transform::PassContext),
tvm::tir::transform::VerifyMemory()::{lambda(tvm::IRModule,
tvm::transform::PassContext)#1}>::_M_invoke(std::_Any_data const&,
tvm::IRModule&&, tvm::transform::PassContext&&)
File "<unknown>", line 0, in
tvm::tir::transform::VerifyMemory()::{lambda(tvm::IRModule,
tvm::transform::PassContext)#1}::operator()(tvm::IRModule,
tvm::transform::PassContext) const [clone .constprop.0]
File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::~LogFatal()
[clone .constprop.0]
File "<unknown>", line 0, in
tvm::runtime::detail::LogFatal::Entry::Finalize()
RuntimeError: Memory verification failed with the following errors:
Variable `T_transpose` is directly accessed by host memory (it is not
contained in a thread environment or in the function arguments.
Variable `p_fc_weight` is directly accessed by host memory (it is not
contained in a thread environment or in the function arguments.
Did you forget to bind?
# from tvm.script import tir as T
@T.prim_func
def transpose(p_fc_weight: T.Buffer((T.int64(1000), T.int64(512)),
"float32"), T_transpose: T.Buffer((T.int64(512), T.int64(1000)), "float32")):
T.func_attr({"op_pattern": 2, "target": T.target({"arch": "sm_87",
"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple":
"aarch64-unknown-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind":
"cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}),
"tir.noalias": True})
for ax0, ax1 in T.grid(512, 1000):
T_transpose_1 = T.Buffer((T.int64(512000),), data=T_transpose.data)
p_fc_weight_1 = T.Buffer((T.int64(512000),), data=p_fc_weight.data)
T_transpose_1[ax0 * 1000 + ax1] = p_fc_weight_1[ax1 * 512 + ax0]
```
(Most of the task table was omitted because I lost the full log, but I do
remember that every layer was tuned for 60+ trials, mostly 64, while the
transpose task was scheduled for only 1 trial.)
</details>
TL;DR of the details above:
```
RuntimeError: Memory verification failed with the following errors:
Variable `T_transpose` is directly accessed by host memory (it is not
contained in a thread environment or in the function arguments.
Variable `p_fc_weight` is directly accessed by host memory (it is not
contained in a thread environment or in the function arguments.
Did you forget to bind?
# from tvm.script import tir as T
@T.prim_func
def transpose(p_fc_weight: T.Buffer((T.int64(1000), T.int64(512)),
"float32"), T_transpose: T.Buffer((T.int64(512), T.int64(1000)), "float32")):
T.func_attr({"op_pattern": 2, "target": T.target({"arch": "sm_87",
"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple":
"aarch64-unknown-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind":
"cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}),
"tir.noalias": True})
for ax0, ax1 in T.grid(512, 1000):
T_transpose_1 = T.Buffer((T.int64(512000),), data=T_transpose.data)
p_fc_weight_1 = T.Buffer((T.int64(512000),), data=p_fc_weight.data)
T_transpose_1[ax0 * 1000 + ax1] = p_fc_weight_1[ax1 * 512 + ax0]
```
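In other words, the `transpose` PrimFunc was left without any thread bindings
(it only ever got 1 trial), so `VerifyMemory` rejects it when building for CUDA.
For illustration only, here is a hypothetical sketch, not taken from my script
or from TVM's output, of what a minimally thread-bound transpose schedule would
look like; the block/loop names and the split factor are made up:
```python
import tvm
from tvm.script import tir as T

# Hypothetical stand-alone transpose with the same shapes as the failing kernel.
@T.prim_func
def transpose(A: T.Buffer((1000, 512), "float32"),
              B: T.Buffer((512, 1000), "float32")):
    for i, j in T.grid(512, 1000):
        with T.block("T_transpose"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vj, vi]

sch = tvm.tir.Schedule(transpose)
blk = sch.get_block("T_transpose")
i, j = sch.get_loops(blk)
fused = sch.fuse(i, j)
bx, tx = sch.split(fused, factors=[None, 256])
sch.bind(bx, "blockIdx.x")   # without these two binds the loops stay on the host,
sch.bind(tx, "threadIdx.x")  # which is exactly what VerifyMemory complains about
```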