guberti opened a new issue, #13010:
URL: https://github.com/apache/tvm/issues/13010
It is often useful to have reduction axes of length 1 — e.g. to
handle a 1x1 kernel in a `conv2d`. However, when using `tensorize` in the
case where the **outer** reduction axis has length one, lowering fails with an
error like the following:
```
TVMError: Traceback (most recent call last):
36: TVMFuncCall
35: _ZN3tvm7runtime13PackedFun
34: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::te::Schedule,
tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String
const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void>
const&, bool)>::AssignTypedLambda<tvm::__mk_TVM16::{lambda(tvm::te::Schedule,
tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String
const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void>
const&, bool)#1}>(tvm::__mk_TVM16::{lambda(tvm::te::Schedule,
tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String
const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void>
const&, bool)#1}, std::string)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs,
tvm::runtime::TVMArgs const&) const
33: tvm::LowerSchedule(tvm::te::Schedule,
tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::string const&,
std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer,
std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>,
std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&,
bool)
32: tvm::LowerWithPassList(tvm::IRModule,
tvm::runtime::Array<tvm::transform::Pass, void>)
31: tvm::transform::Pass::operator()(tvm::IRModule) const
30: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
29: tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
28: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
27: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
26:
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform14StorageFlattenEibEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
25: tvm::tir::StorageFlatten(tvm::tir::PrimFunc, int, bool)
24: tvm::transform::Pass::operator()(tvm::IRModule) const
23: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
22: tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
21: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
20: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
19: _ZN3tvm7runtime13PackedFun
18: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc,
tvm::IRModule,
tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::BufferBindUnwrapper::Pass()::{lambda(tvm::tir::PrimFunc,
tvm::IRModule,
tvm::transform::PassContext)#1}>(tvm::tir::BufferBindUnwrapper::Pass()::{lambda(tvm::tir::PrimFunc,
tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs
const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const,
tvm::runtime::TVMRetValue) const
17: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
16: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt
const&)>::VisitStmt(tvm::tir::Stmt const&)
15: _ZZN3tvm3tir11StmtFunctorI
14: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::ForNode const*)
13: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
12: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt
const&)>::VisitStmt(tvm::tir::Stmt const&)
11: _ZZN3tvm3tir11StmtFunctorI
10: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::ForNode const*)
9: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
8: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt
const&)>::VisitStmt(tvm::tir::Stmt const&)
7: _ZZN3tvm3tir11StmtFunctorI
6: tvm::tir::BufferBindUnwrapper::VisitStmt_(tvm::tir::AttrStmtNode const*)
5:
tvm::tir::BufferBindUnwrapper::HandleBufferBindScope(tvm::tir::AttrStmtNode
const*)
4: tvm::tir::ArgBinder::BindBuffer(tvm::tir::Buffer const&,
tvm::tir::Buffer const&, std::string const&, bool)
3: tvm::tir::ArgBinder::BindArray(tvm::runtime::Array<tvm::PrimExpr, void>
const&, tvm::runtime::Array<tvm::PrimExpr, void> const&, std::string const&)
2: tvm::tir::ArgBinder::Bind_(tvm::PrimExpr const&, tvm::PrimExpr const&,
std::string const&, bool)
1: tvm::tir::BinderAddAssert(tvm::arith::Analyzer*, tvm::PrimExpr,
std::string const&, std::vector<tvm::tir::Stmt, std::allocator<tvm::tir::Stmt>
>*)
0: _ZN3tvm7runtime6deta
File "/workspace/tvm/src/tir/transforms/arg_binder.cc", line 40
TVMError: Bind have an unmet assertion: (bool)0, on argument
foobar.strides[1]
```
Note that this bug only occurs when the outer reduction axis has length one;
length-one reduction axes in any other position work fine. The issue occurs on
the latest version of TVM, as well as on older versions such as `0.9.0`.
I've found a hack to work around this bug, and have used it in some of my
PRs:
https://github.com/apache/tvm/blob/981b1bdb4780b27cc673ba9cfcd29bb56322ba54/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py#L158-L162
The hack is pretty gross, though, so it would be nice to get a proper fix.
### Steps to reproduce
The easiest way to reproduce this is to use a Colab notebook:
https://colab.research.google.com/drive/1Y9LXBdQxQD-FjNbW6cChuExse6IvMdGZ
You can also reproduce it using this script:
[bug_reproduction.py.txt](https://github.com/apache/tvm/files/9734351/bug_reproduction.py.txt)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]