Cookiee235 opened a new issue, #17254:
URL: https://github.com/apache/tvm/issues/17254
I came across an unexpected crash when executing the following script.
The crash can only be triggered by a specific sequence of multiple transforms, i.e.,
`[ToMixedPrecision, LegalizeOps, AnnotateTIROpPattern, FuseOps, FuseTIR]`.
Removing any one of these passes makes the bug no longer reproducible.
### Actual behavior
```
Traceback (most recent call last):
File "test.py", line 92, in <module>
mod = relax.transform.FuseTIR()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in
__call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line
240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in
raise_last_ffi_error
raise py_err
tvm.error.InternalError: Traceback (most recent call last):
22:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
(tvm::transform::Pass,
tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>,
std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*,
std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>
>, tvm::runtime::TVMRetValue)
21: tvm::transform::Pass::operator()(tvm::IRModule) const
20: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
19: tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
18: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
17: tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
16:
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
15: tvm::relax::FuseTIR(tvm::IRModule)
14: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
13: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&,
tvm::GlobalVar const&)
12: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode
const*)
10: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
9: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
8: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock
const&)
7:
tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode
const*)
6: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
5: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode
const*)
4: _ZN3tvm5relax11ExprVisitor13VisitBinding_EPKNS0_14Va
3: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
2: tvm::relax::RelaxToTIRVarMapCollector::VisitExpr_(tvm::relax::CallNode
const*)
1:
tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode
const*, tvm::RelayExpr const&, bool)
0:
tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode
const*, tvm::RelayExpr const&, bool)::{lambda(tvm::tir::Buffer,
tvm::RelayExpr)#1}::operator()(tvm::tir::Buffer, tvm::RelayExpr) const
File "/software/tvm-lunder/src/relax/transform/fuse_tir.cc", line 442
InternalError: Check failed: (StructuralEqual()((*it).second, new_buf)) is
false: Inconsistent buffers compute and lv mapped to the same relax var: lv11
```
### Steps to reproduce
<details>
<summary>Full test script</summary>
```python
import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def conv2d2(data: T.Buffer((T.int64(16), T.int64(32), T.int64(32),
T.int64(16)), "float16"), weight1: T.Buffer((T.int64(16), T.int64(3),
T.int64(3), T.int64(16)), "float16"), conv2d_nhwc: T.Buffer((T.int64(16),
T.int64(32), T.int64(32), T.int64(16)), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
pad_temp = T.alloc_buffer((T.int64(16), T.int64(34), T.int64(34),
T.int64(16)), "float16")
for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34), T.int64(34),
T.int64(16)):
with T.block("pad_temp"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1),
v_i3])
T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1)
<= v_i1 and v_i1 < T.int64(33) and T.int64(1) <= v_i2 and v_i2 < T.int64(33),
data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], T.float16(0))
for nn, yy, xx, ff, ry, rx, rc in T.grid(T.int64(16), T.int64(32),
T.int64(32), T.int64(16), T.int64(3), T.int64(3), T.int64(16)):
with T.block("conv2d_nhwc"):
v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc =
T.axis.remap("SSSSRRR", [nn, yy, xx, ff, ry, rx, rc])
T.reads(pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc],
weight1[v_ff, v_ry, v_rx, v_rc])
T.writes(conv2d_nhwc[v_nn, v_yy, v_xx, v_ff])
with T.init():
conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = T.float16(0)
conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = conv2d_nhwc[v_nn,
v_yy, v_xx, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc] *
weight1[v_ff, v_ry, v_rx, v_rc]
@T.prim_func(private=True)
def layer_norm(conv2: T.Buffer((T.int64(16), T.int64(32), T.int64(32),
T.int64(16)), "float16"), gamma: T.Buffer((T.int64(16),), "float16"), beta:
T.Buffer((T.int64(16),), "float16"), T_layer_norm: T.Buffer((T.int64(16),
T.int64(32), T.int64(32), T.int64(16)), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
conv2_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32),
T.int64(32)))
conv2_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32),
T.int64(32)))
for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32),
T.int64(32), T.int64(16)):
with T.block("conv2_red_temp"):
v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0, ax1,
ax2, k3])
T.reads(conv2[v_ax0, v_ax1, v_ax2, v_k3])
T.writes(conv2_red_temp_v0[v_ax0, v_ax1, v_ax2],
conv2_red_temp_v1[v_ax0, v_ax1, v_ax2])
with T.init():
conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
v_conv2_red_temp_v0: T.float32 = conv2_red_temp_v0[v_ax0,
v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
v_conv2_red_temp_v1: T.float32 = conv2_red_temp_v1[v_ax0,
v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3]) *
T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v0
conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v1
for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32),
T.int64(32), T.int64(16)):
with T.block("T_layer_norm"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1,
ax2, ax3])
T.reads(conv2[v_ax0, v_ax1, v_ax2, v_ax3],
conv2_red_temp_v0[v_ax0, v_ax1, v_ax2], conv2_red_temp_v1[v_ax0, v_ax1, v_ax2],
gamma[v_ax3], beta[v_ax3])
T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16",
(T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_ax3]) -
conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) *
T.rsqrt(conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) -
conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) *
(conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) +
T.float32(1.0000000000000001e-05))) * gamma[v_ax3] + beta[v_ax3]
@T.prim_func(private=True)
def relu(lv: T.Buffer((T.int64(16), T.int64(32), T.int64(32),
T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(32),
T.int64(32), T.int64(16)), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32), T.int64(32),
T.int64(16)):
with T.block("compute"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
T.reads(lv[v_i0, v_i1, v_i2, v_i3])
T.writes(compute[v_i0, v_i1, v_i2, v_i3])
compute[v_i0, v_i1, v_i2, v_i3] = T.max(lv[v_i0, v_i1, v_i2,
v_i3], T.float16(0))
@R.function
def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight1:
R.Tensor((16, 3, 3, 16), dtype="float16"), weight2: R.Tensor((16, 3, 3, 16),
dtype="float16"), weight3: R.Tensor((16, 3, 3, 16), dtype="float16"), gamma:
R.Tensor((16,), dtype="float16"), beta: R.Tensor((16,), dtype="float16")) ->
R.Tensor((16, 32, 32, 16), dtype="float16"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
lv = R.call_tir(cls.conv2d2, (data, weight1),
out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
conv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((16, 32,
32, 16), dtype="float16"))
lv2 = R.call_tir(cls.conv2d2, (conv1, weight2),
out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
conv2 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((16, 32,
32, 16), dtype="float16"))
ln = R.call_tir(cls.layer_norm, (conv2, gamma, beta),
out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
lv5 = R.call_tir(cls.conv2d2, (ln, weight3),
out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
conv3 = R.call_tir(cls.relu, (lv5,), out_sinfo=R.Tensor((16, 32,
32, 16), dtype="float16"))
R.output(conv3)
return conv3
mod = Module
mod = relax.transform.ToMixedPrecision()(mod)
mod = relax.transform.LegalizeOps()(mod)
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)
```
</details>
cc @Lunderberg @junrushao
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]