zyc-bit opened a new issue, #15795: URL: https://github.com/apache/tvm/issues/15795
### Expected behavior

The PyTorch traced model is successfully converted to TVM's Relay IR.

### Actual behavior

An error occurs when converting `index_put_()`. The PyTorch model's forward pass contains:

`far[rays_d[:, 2] >= 0] = self.near_far[-1]`

where the shapes of `far`, `rays_d`, and `self.near_far[-1]` are as follows:

and the error is:

```
Traceback (most recent call last):
  File "from_torch_test.py", line 107, in <module>
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 5038, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 4291, in convert_operators
    relay_out = relay_op(
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 2708, in index_put
    return _op.scatter_nd(in_tensor, index_tensor, values, mode)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/op/transform.py", line 422, in scatter_nd
    return _make.scatter_nd(data, indices, updates, mode)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  2: _ZN3tvm7runtime13PackedFun
  1: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::String)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::String)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::String), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  0: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::RelayExpr<tvm::RelayExpr>() const
  3: _ZN3tvm7runtime13PackedFun
  2: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::String)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::String)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::String), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  1: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::RelayExpr<tvm::RelayExpr>() const
  0: tvm::RelayExpr tvm::runtime::TVMPODValue_::AsObjectRef<tvm::RelayExpr>() const
  File "/cpfs01/user/zhangyuchang/projects/tvm/include/tvm/runtime/packed_func.h", line 779
TVMError: In function relay.op._make.scatter_nd(0: RelayExpr, 1: RelayExpr, 2: RelayExpr, 3: runtime.String) -> RelayExpr: error while converting argument 2: [02:51:56] /cpfs01/user/zhangyuchang/projects/tvm/include/tvm/runtime/packed_func.h:1977: InternalError: Check failed: type_code_ == kTVMObjectHandle (2 vs. 8) : expected Object but got float
```

What's more, `index_put_()` also fails when the tensors involved have different shapes, such as `sigma[ray_valid] = validsigma`, where `ray_valid` is a 2D boolean (`True`/`False`) tensor and the number of elements of `validsigma` equals the number of `True` entries in `ray_valid`.
The error for this case is:

```
Traceback (most recent call last):
  File "from_torch_test.py", line 107, in <module>
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 5038, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 4298, in convert_operators
    self.record_output_type(relay_out)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 238, in record_output_type
    self.infer_type_with_prelude(output)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 174, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/relay/frontend/pytorch.py", line 167, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/ir/transform.py", line 160, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/cpfs01/user/zhangyuchang/projects/tvm/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  9: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  8: tvm::transform::Pass::operator()(tvm::IRModule) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  4: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  3: tvm::relay::TypeSolver::Solve()
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::relay::ScatterNDRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  0: tvm::runtime::Array<tvm::PrimExpr, void>::operator[](long) const
  File "/cpfs01/user/zhangyuchang/projects/tvm/include/tvm/runtime/container/array.h", line 414
InternalError: Check failed: (0 <= i && i < p->size_) is false: IndexError: indexing 1 on an array of size 1
```

### Environment

- Operating System: Ubuntu 18.04
- TVM version: 0.14.dev0
- PyTorch version: 2.1.0
- Python version: 3.8.16

### Steps to reproduce

Write a simple test.py containing a PyTorch model whose forward pass performs an `A[B] = C` operation; vary the shapes of `A`, `B`, and `C` to match the two error cases mentioned above. Then trace the model with `torch.jit.trace` and convert the traced model with TVM, and the error will occur. An example:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.y = torch.tensor([[0, 1, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1]])
        self.z = torch.tensor([2., 3., 4., 5., 6., 7., 8., 9.])

    def forward(self, x):
        x[self.y > 0] = self.z
        return x

net = Net()
input_shape = (2, 6)
input_data = torch.zeros(input_shape)
out = net(input_data)
traced_net = torch.jit.trace(net, input_data)
traced_net.save("tobetest09201.pt")
```

Change the shapes of `self.y` and `self.z` to match the two cases mentioned above, and then:

```python
scripted_model = torch.jit.load("your traced model path").eval()
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)  # shape_list depends on what you set before
```

### Triage

* frontend:pytorch
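For reference, the rewrite the frontend has to perform here can be illustrated without any framework. The sketch below is only an analogy, not TVM code: plain Python lists stand in for tensors, and the helper name `mask_to_scatter` is invented for illustration. It converts a boolean-mask assignment `x[mask] = values` into the explicit `(indices, updates)` pairs that a `scatter_nd`-style op consumes. Case 2 above is exactly this pattern (one update per `True` entry), and case 1 additionally needs the scalar right-hand side broadcast to one update per `True` entry rather than being passed through as a bare `float`:

```python
def mask_to_scatter(x, mask, values):
    """Rewrite x[mask] = values as explicit (indices, updates) pairs.

    x:      2D list of floats (stand-in for the destination tensor)
    mask:   2D list of 0/1 or bool flags, same shape as x
    values: either a scalar, or a flat list with one entry per True flag
    """
    # Coordinates of every True entry, in row-major order (the order
    # PyTorch uses when consuming the right-hand side).
    indices = [(i, j)
               for i, row in enumerate(mask)
               for j, flag in enumerate(row) if flag]
    if isinstance(values, (int, float)):
        # Case 1 analogue: broadcast the scalar to one update per index.
        updates = [float(values)] * len(indices)
    else:
        # Case 2 analogue: consume the flat values positionally.
        assert len(values) == len(indices), "one update per True entry"
        updates = list(values)
    out = [row[:] for row in x]  # copy, mimicking a functional scatter
    for (i, j), v in zip(indices, updates):
        out[i][j] = v
    return out, indices, updates

# The repro's shapes: a (2, 6) destination, a mask with 8 True entries,
# and 8 update values.
x = [[0.0] * 6 for _ in range(2)]
mask = [[0, 1, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1]]
vals = [2., 3., 4., 5., 6., 7., 8., 9.]
out, idx, upd = mask_to_scatter(x, mask, vals)
# Case 1 analogue: a scalar RHS broadcast across the True positions.
out_scalar, _, upd_scalar = mask_to_scatter(x, mask, 9.0)
```

The point of the sketch is that both failing lines reduce to the same `(indices, updates)` form once the mask is flattened; the converter appears to mishandle the scalar-RHS broadcast (case 1) and the indices/updates rank bookkeeping (case 2).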
