Cookiee235 opened a new issue, #17254:
URL: https://github.com/apache/tvm/issues/17254

   I came across an unexpected crash when executing the script below. The crash is triggered only by the full sequence of transforms `[ToMixedPrecision, LegalizeOps, AnnotateTIROpPattern, FuseOps, FuseTIR]`; removing any one of these passes makes the bug unreproducible (see the sketch below).
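
   To double-check that every pass is required, the pipeline can be rerun with one pass omitted at a time. A minimal sketch (not part of the original report; it assumes the `Module` defined in the repro script under "Steps to reproduce"):

    ```python
    from tvm import relax

    PASSES = [
        relax.transform.ToMixedPrecision,
        relax.transform.LegalizeOps,
        relax.transform.AnnotateTIROpPattern,
        relax.transform.FuseOps,
        relax.transform.FuseTIR,
    ]

    def run_pipeline(mod, skip=None):
        # Apply the pass sequence, optionally omitting the pass at index `skip`.
        for i, make_pass in enumerate(PASSES):
            if i == skip:
                continue
            mod = make_pass()(mod)
        return mod

    # run_pipeline(Module) crashes, while run_pipeline(Module, skip=i) completes
    # for every i -- this is what "removing any one of these passes" refers to.
    ```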
   
   ### Actual behavior
   
    ```
    Traceback (most recent call last):
      File "test.py", line 92, in <module>
        mod = relax.transform.FuseTIR()(mod)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
        return _ffi_transform_api.RunPass(self, mod)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
        raise_last_ffi_error()
      File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
        raise py_err
    tvm.error.InternalError: Traceback (most recent call last):
      22: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
      21: tvm::transform::Pass::operator()(tvm::IRModule) const
      20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
      19: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
      18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
      17: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
      16: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
      15: tvm::relax::FuseTIR(tvm::IRModule)
      14: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
      13: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
      12: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
      10: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      9: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
      8: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
      7: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
      6: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
      5: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
      4: _ZN3tvm5relax11ExprVisitor13VisitBinding_EPKNS0_14Va
      3: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      2: tvm::relax::RelaxToTIRVarMapCollector::VisitExpr_(tvm::relax::CallNode const*)
      1: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)
      0: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)::{lambda(tvm::tir::Buffer, tvm::RelayExpr)#1}::operator()(tvm::tir::Buffer, tvm::RelayExpr) const
      File "/software/tvm-lunder/src/relax/transform/fuse_tir.cc", line 442
    InternalError: Check failed: (StructuralEqual()((*it).second, new_buf)) is false: Inconsistent buffers compute and lv mapped to the same relax var: lv11

    ```
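
   The failing check at `fuse_tir.cc:442` compares, with `StructuralEqual`, the TIR buffer already recorded for a relax var against the buffer a later fused call maps it to. A minimal illustration of that comparison using public APIs (not the actual fuse_tir.cc code; the buffer names follow the error message, and the dtype mismatch is only an assumed example of inconsistency):

    ```python
    import tvm
    from tvm import tir

    # Two buffers of the same shape but different dtypes, as could arise when
    # ToMixedPrecision changes the dtype seen by one of the fused callees.
    compute = tir.decl_buffer((16, 32, 32, 16), "float16", name="compute")
    lv = tir.decl_buffer((16, 32, 32, 16), "float32", name="lv")

    # FuseTIR requires all buffers mapped to one relax var (here `lv11`) to be
    # structurally equal; these two are not, which is when the ICHECK fires.
    assert not tvm.ir.structural_equal(compute, lv)
    ```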
   
   
   ### Steps to reproduce
   
   <details>
   
   <summary>Full test script</summary>
   
   
    ```python
    import tvm
    from tvm import relax
    import numpy as np
    from tvm.script import ir as I
    from tvm.script import tir as T
    from tvm.script import relax as R


    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def conv2d2(data: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), weight1: T.Buffer((T.int64(16), T.int64(3), T.int64(3), T.int64(16)), "float16"), conv2d_nhwc: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            pad_temp = T.alloc_buffer((T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16")
            for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34), T.int64(34), T.int64(16)):
                with T.block("pad_temp"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
                    T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                    pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(33) and T.int64(1) <= v_i2 and v_i2 < T.int64(33), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], T.float16(0))
            for nn, yy, xx, ff, ry, rx, rc in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16), T.int64(3), T.int64(3), T.int64(16)):
                with T.block("conv2d_nhwc"):
                    v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc = T.axis.remap("SSSSRRR", [nn, yy, xx, ff, ry, rx, rc])
                    T.reads(pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc], weight1[v_ff, v_ry, v_rx, v_rc])
                    T.writes(conv2d_nhwc[v_nn, v_yy, v_xx, v_ff])
                    with T.init():
                        conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = T.float16(0)
                    conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc] * weight1[v_ff, v_ry, v_rx, v_rc]

        @T.prim_func(private=True)
        def layer_norm(conv2: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), gamma: T.Buffer((T.int64(16),), "float16"), beta: T.Buffer((T.int64(16),), "float16"), T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            conv2_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32)))
            conv2_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32)))
            for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
                with T.block("conv2_red_temp"):
                    v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0, ax1, ax2, k3])
                    T.reads(conv2[v_ax0, v_ax1, v_ax2, v_k3])
                    T.writes(conv2_red_temp_v0[v_ax0, v_ax1, v_ax2], conv2_red_temp_v1[v_ax0, v_ax1, v_ax2])
                    with T.init():
                        conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
                        conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
                    v_conv2_red_temp_v0: T.float32 = conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
                    v_conv2_red_temp_v1: T.float32 = conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3]) * T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
                    conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v0
                    conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v1
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
                with T.block("T_layer_norm"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(conv2[v_ax0, v_ax1, v_ax2, v_ax3], conv2_red_temp_v0[v_ax0, v_ax1, v_ax2], conv2_red_temp_v1[v_ax0, v_ax1, v_ax2], gamma[v_ax3], beta[v_ax3])
                    T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", (T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_ax3]) - conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) * T.rsqrt(conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) - conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) * (conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) + T.float32(1.0000000000000001e-05))) * gamma[v_ax3] + beta[v_ax3]

        @T.prim_func(private=True)
        def relu(lv: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
                with T.block("compute"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(lv[v_i0, v_i1, v_i2, v_i3])
                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                    compute[v_i0, v_i1, v_i2, v_i3] = T.max(lv[v_i0, v_i1, v_i2, v_i3], T.float16(0))

        @R.function
        def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), weight3: R.Tensor((16, 3, 3, 16), dtype="float16"), gamma: R.Tensor((16,), dtype="float16"), beta: R.Tensor((16,), dtype="float16")) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.conv2d2, (data, weight1), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
                conv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
                lv2 = R.call_tir(cls.conv2d2, (conv1, weight2), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
                conv2 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
                ln = R.call_tir(cls.layer_norm, (conv2, gamma, beta), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
                lv5 = R.call_tir(cls.conv2d2, (ln, weight3), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
                conv3 = R.call_tir(cls.relu, (lv5,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
                R.output(conv3)
            return conv3


    mod = Module

    mod = relax.transform.ToMixedPrecision()(mod)
    mod = relax.transform.LegalizeOps()(mod)
    mod = relax.transform.AnnotateTIROpPattern()(mod)
    mod = relax.transform.FuseOps()(mod)
    mod = relax.transform.FuseTIR()(mod)
    ```
   
   </details>
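
   To see the fused function in which the conflicting binding `lv11` appears, the module can be printed just before the failing pass (a suggested inspection step, not part of the original script):

    ```python
    mod = relax.transform.FuseOps()(mod)
    print(mod.script())  # dump the module after FuseOps, before FuseTIR crashes
    mod = relax.transform.FuseTIR()(mod)
    ```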
   
   
   cc @Lunderberg @junrushao 

