Cookiee235 opened a new issue, #17255:
URL: https://github.com/apache/tvm/issues/17255
### Actual behavior
```
Traceback (most recent call last):
File "test.py", line 88, in <module>
ex = relax.build(mod, target='llvm')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/relax/vm_build.py", line 335, in
build
mod = pipeline(mod)
^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in
__call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line
240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in
raise_last_ffi_error
raise py_err
File "/software/tvm-lunder/python/tvm/relax/pipeline.py", line 101, in
_pipeline
mod = seq(mod)
^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in
__call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line
240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in
raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
10:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
(tvm::transform::Pass,
tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>,
std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*,
std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>
>, tvm::runtime::TVMRetValue)
9: tvm::transform::Pass::operator()(tvm::IRModule) const
8: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
7: tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
6: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
4:
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform21StaticPlanBlockMemoryEvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
3: tvm::relax::StaticPlanBlockMemory(tvm::IRModule)
2: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
1: tvm::relax::StorageAllocatorInit::VisitExpr_(tvm::relax::FunctionNode
const*)
0: tvm::relax::SetTIRVarUpperBound(tvm::relax::Function,
tvm::arith::Analyzer*, tvm::runtime::Map<tvm::tir::Var, tvm::arith::IntSet,
void, void>*)
File
"/software/tvm-lunder/src/relax/transform/static_plan_block_memory.cc", line 378
TVMError: Check failed: (value != nullptr) is false: The entry value of attr
`tir_var_upper_bound` should be integer. However Array is got.
```
### Steps to reproduce
```python
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func
    def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.evaluate(0)

    @T.prim_func
    def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def relu(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def reshape1(alloc: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(8)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(8), ax0)
                T.reads(alloc[v_ax0 % T.int64(8) // T.int64(4), v_ax0 % T.int64(4)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = alloc[v_ax0 % T.int64(8) // T.int64(4), v_ax0 % T.int64(4)]

    @R.function
    def main_2(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        R.func_attr({"relax.force_pure": 1})
        cls = Module
        storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
        alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
        cls.exp(x, alloc)
        lv1 = R.call_tir(cls.reshape1, (alloc,), out_sinfo=R.Tensor((8,), dtype="float32"))
        R.memory.kill_tensor(alloc)
        storage1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("float32"))
        alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([8]), R.dtype("float32"))
        cls.relu(lv1, alloc1)
        R.memory.kill_tensor(lv1)
        alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([8]), R.dtype("float32"))
        R.memory.kill_storage(storage)
        cls.add(alloc1, R.const(1, "float32"), alloc2)
        R.memory.kill_tensor(alloc1)
        alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([10]), R.dtype("float32"))
        R.memory.kill_storage(storage1)
        cls.pad(alloc2, alloc3)
        R.memory.kill_tensor(alloc2)
        storage_1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("uint8"))
        alloc4: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage_1, R.prim_value(0), R.shape([10]), R.dtype("float32"))
        R.memory.kill_storage(storage_1)
        cls.log(alloc3, alloc4)
        R.memory.kill_tensor(alloc3)
        return alloc4

    @R.function
    def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        n = T.int64()
        R.func_attr({"relax.force_pure": 1, "tir_var_upper_bound": {"n": [4]}})
        cls = Module
        para0: R.Tensor((2, n), dtype="float32") = x
        res: R.Tensor((10,), dtype="float32") = cls.main_2(para0)
        return res


mod = Module
ex = relax.build(mod, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())
```
cc @Lunderberg @junrushao
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]