JamesTheZ opened a new issue, #13664:
URL: https://github.com/apache/tvm/issues/13664

   The latest TVM fails to convert the standard Hugging Face BERT model.
   The script:
   
   ``` python
   from transformers import BertTokenizer, BertConfig, BertModel
   import tvm
   from tvm import relay
   import torch
   
   
   class AutoModelAMP(BertModel):
       def __init__(self, *args, **keyargs):
           with torch.cuda.amp.autocast():
               super().__init__(*args, **keyargs)
   
       def forward(self, *args, **keyargs):
           with torch.cuda.amp.autocast():
               return super().forward(*args, **keyargs)
   
   
   config = BertConfig.from_pretrained("bert-large-uncased")
   config.return_dict = False
   config.torchscript = True
   model = AutoModelAMP(config).cuda().eval()
   
   batch = 1
   seq_len = 128
   input = torch.zeros([batch, seq_len], dtype=torch.int).long()
   inputs = {
       "input_ids": input.cuda(),
       "attention_mask": input.cuda(),
       "token_type_ids": input.cuda(),
   }
   input_list = []
   for k, v in inputs.items():
       input_list.append(v)
   
   shape_list = []
   for k, v in inputs.items():
       shape_list.append((k, v.shape))
   
   traced_model = torch.jit.trace(
       model.cuda(), input_list, strict=False).cuda().eval()
   mod, params = relay.frontend.from_pytorch(traced_model, shape_list)
   ```
   Error log:
   ```
   Traceback (most recent call last):
     File "tvm_bug.py", line 40, in <module>
       mod, params = relay.frontend.from_pytorch(traced_model, shape_list)
     File "/home/workspace/release/tvm/python/tvm/relay/frontend/pytorch.py", line 4653, in from_pytorch
       outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
     File "/home/workspace/release/tvm/python/tvm/relay/frontend/pytorch.py", line 4026, in convert_operators
       relay_out = relay_op(
     File "/home/workspace/release/tvm/python/tvm/relay/frontend/pytorch.py", line 1704, in linear
       mm_out = self.matmul(
     File "/home/workspace/release/tvm/python/tvm/relay/frontend/pytorch.py", line 1932, in matmul
       out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
     File "/home/workspace/release/tvm/python/tvm/relay/op/transform.py", line 325, in reshape
       tempshape.append(int(shape))
   TypeError: int() argument must be a string, a bytes-like object or a number, not 'Any'
   ```
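   The `a_shape` and `b_shape` at `tvm/python/tvm/relay/frontend/pytorch.py:1932` are `(?, 128, 1024)` and `(1024, 1024)`.
   The leading dimension of `a_shape` is dynamic (`Any`), so `_op.reshape` fails when it calls `int()` on that dimension. It looks like the reshape of the matmul output does not handle dynamic shapes properly.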
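
To make the failure mode concrete, here is a minimal sketch in plain Python that does not import TVM. The `Any` class, `to_static_shape`, and `to_reshape_target` are hypothetical stand-ins written for illustration; they are not TVM's implementation. The sketch reproduces the pattern from the traceback (calling `int()` on every dimension) and shows one possible workaround, assuming at most one dimension is dynamic so that it can be passed to reshape as `-1` (the "infer this dimension" value).

```python
class Any:
    """Hypothetical stand-in for a dynamic (unknown) dimension; not TVM's class."""

    def __repr__(self):
        return "?"


def to_static_shape(shape):
    # Mirrors the failing pattern in the traceback: int() is applied to
    # every dimension, which raises TypeError for a dynamic one.
    return [int(dim) for dim in shape]


def to_reshape_target(shape):
    # Possible workaround (an assumption, not the upstream fix): replace a
    # single dynamic dimension with -1 so that reshape can infer it.
    dynamic = [i for i, dim in enumerate(shape) if isinstance(dim, Any)]
    if len(dynamic) > 1:
        raise ValueError("more than one dynamic dimension cannot be inferred with -1")
    return [-1 if isinstance(dim, Any) else int(dim) for dim in shape]


# Shapes as reported in the issue: (?, 128, 1024) and (1024, 1024).
a_shape = (Any(), 128, 1024)
b_shape = (1024, 1024)
new_shape = [*a_shape[:-1], b_shape[-1]]

try:
    to_static_shape(new_shape)
except TypeError as err:
    print("fails as in the report:", err)

print(to_reshape_target(new_shape))  # [-1, 128, 1024]
```

Whether `-1` is appropriate in the actual frontend depends on how many dimensions can be dynamic at that call site; the sketch only covers the single-dynamic-dimension case seen here.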
   
   #### Env:
   TVM commit: e2680142ef8d301ae5dfc10b89594fd388219a7e
   huggingface transformers version: 4.25.1
   torch: 1.12.0+cu113


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
