honghuichao commented on issue #13759:
URL: https://github.com/apache/tvm/issues/13759#issuecomment-1379733000
This is a minimal demo extracted from my network. @masahi
```
import torch
from pdb import set_trace as d
import onnx
from tvm import relay
class TestIndexPut(torch.nn.Module):
    """Minimal module that scatters values into a tensor with ``index_put_``.

    This script exports the module with ``torch.onnx.export`` and passes the
    resulting graph to ``tvm.relay.frontend.onnx.from_onnx``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, batch_src_corr_points, indices0, indices1, src_corr_points):
        """Write ``src_corr_points`` into ``batch_src_corr_points`` in place.

        Args:
            batch_src_corr_points: destination tensor; modified in place.
            indices0: index tensor for dimension 0 of the destination.
            indices1: index tensor for dimension 1 of the destination.
            src_corr_points: values written at ``(indices0, indices1)``.

        Returns:
            The same ``batch_src_corr_points`` object after the update.
        """
        # The in-place index_put_ is the op this repro exercises, so keep it
        # in-place rather than switching to the out-of-place index_put.
        batch_src_corr_points.index_put_([indices0, indices1], src_corr_points)
        return batch_src_corr_points
model = TestIndexPut()

# Destination tensor that index_put_ writes into.
dummy_batch_src_corr_points = torch.rand(136888, 3).float()
# Index tensors drawn from the valid range of each indexed dimension.
# torch.rand(...).long() would truncate values in [0, 1) to 0, making every
# index 0 so all writes land on element [0, 0].
dummy_indices0 = torch.randint(0, dummy_batch_src_corr_points.shape[0], (50168, 3))
dummy_indices1 = torch.randint(0, dummy_batch_src_corr_points.shape[1], (50168, 3))
# Values to scatter; shape matches the index tensors.
dummy_src_corr_points = torch.rand(50168, 3).float()

# Eager sanity check that the module runs on these inputs.
# Note: this mutates dummy_batch_src_corr_points in place before export.
dummy_output = model(
    dummy_batch_src_corr_points,
    dummy_indices0,
    dummy_indices1,
    dummy_src_corr_points,
)

input_names = [
    "dummy_batch_src_corr_points",
    "dummy_indices0",
    "dummy_indices1",
    "dummy_src_corr_points",
]
output_names = ["batch_src_corr_points"]
# Mark both dimensions of every input as dynamic in the exported graph.
dynamic_axes = {
    "dummy_batch_src_corr_points": {0: "bscp_shape0", 1: "bscp_shape1"},
    "dummy_indices0": {0: "i0_shape0", 1: "i0_shape1"},
    "dummy_indices1": {0: "i1_shape0", 1: "i1_shape1"},
    "dummy_src_corr_points": {0: "o_shape0", 1: "o_shape1"},
}

with torch.no_grad():
    torch.onnx.export(
        model.eval(),
        (
            dummy_batch_src_corr_points,
            dummy_indices0,
            dummy_indices1,
            dummy_src_corr_points,
        ),
        "TestIndexPut.onnx",
        input_names=input_names,
        output_names=output_names,
        opset_version=11,
        dynamic_axes=dynamic_axes,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    )

# Load the exported graph and import it into TVM Relay; this import step is
# presumably where the problem reported in this issue surfaces.
m = onnx.load("TestIndexPut.onnx")
relay.frontend.onnx.from_onnx(m)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]