This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f21a17b67c [Pytorch] frontend full_impl fix (#14122)
f21a17b67c is described below

commit f21a17b67cfd1efb40dbb2271ebca28e7a7c450f
Author: Jian Sheng <[email protected]>
AuthorDate: Sat Feb 25 20:16:29 2023 -0800

    [Pytorch] frontend full_impl fix (#14122)
    
    Minor fix in the PyTorch frontend that allows the GPT-2 model to compile. Reproduction script, tested with:
    torch_version = 1.13.1
    transformers_version = 4.26.1
    
    ```
    from transformers import GPT2LMHeadModel
    import torch
    import tvm
    from tvm import relay
    import tvm.contrib.graph_executor
    
    inp = torch.ones((1, 128)).to(torch.int64)
    input_shapes = [("input_ids", ((1, 128), "int64"))]
    
    model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=False)
    trace_model = torch.jit.trace(model, inp, strict=False)
    outputs = trace_model(inp)
    
    mod, params = relay.frontend.from_pytorch(trace_model, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target='llvm', params=params)
    
    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device('cpu', 0)))
    runtime.set_input("input_ids", inp.numpy())
    runtime.run()
    out = runtime.get_output(0).numpy()
    print(out)
    print('Done...')
    ```
    
    Before the fix, `relay.frontend.from_pytorch` failed with:
    ```
    Traceback (most recent call last):
      File "gpt2_compile.py", line 13, in <module>
        mod, params = relay.frontend.from_pytorch(trace_model, input_shapes)
      File "/home/ubuntu/apache_tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 4791, in from_pytorch
        outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
      File "/home/ubuntu/apache_tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 4164, in convert_operators
        relay_out = relay_op(
      File "/home/ubuntu/apache_tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 841, in full
        return self.full_impl(data, fill_value, dtype)
      File "/home/ubuntu/apache_tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 743, in full_impl
        fill_value = _expr.const(fill_value, dtype=dtype)
      File "/home/ubuntu/apache_tvm/tvm/python/tvm/relay/expr.py", line 707, in const
        raise ValueError("value has to be scalar or NDArray")
    ValueError: value has to be scalar or NDArray
    ```
    
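    because `fill_value` reaches `full_impl` as an unevaluated Relay expression. `_expr.const` only accepts a scalar or an NDArray, so it rejects this expression: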
    ```
    %0 = cast(64, dtype="float32");
    power(%0, 0.5f)
    ```
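    The failure and the fix can be reproduced in isolation. This is an illustrative sketch, not part of the commit. It rests on three assumptions:
    
    - The TVM build has the LLVM backend enabled, which `infer_value` requires because it compiles and runs the expression on the CPU.
    - The `_infer_value` used in the patch is the frontend's alias for `tvm.relay.frontend.common.infer_value`.
    - The expression can be rebuilt by hand from the IR printed above; it is not taken from the actual GPT-2 trace.
    
    ```python
    from tvm import relay
    from tvm.relay.frontend.common import infer_value
    
    # Hand-built equivalent of the reported fill_value: power(cast(64, "float32"), 0.5f).
    fill_value = relay.power(relay.cast(relay.const(64), "float32"), relay.const(0.5))
    
    # Pre-fix behaviour: passing the Relay expression straight to const()
    # raises, because const() accepts only scalars, NumPy arrays, or NDArrays.
    rejected = False
    try:
        relay.const(fill_value, dtype="float32")
    except ValueError:
        rejected = True  # "value has to be scalar or NDArray"
    
    # Post-fix behaviour (sketch): evaluate the expression to a concrete NDArray
    # first, then wrap the result as a Relay constant.
    value = infer_value(fill_value, {})
    const = relay.const(value, dtype="float32")
    
    print(rejected, value.numpy())  # sqrt(64) evaluates to 8.0
    ```
    
    Folding the value at conversion time only works because the expression has no free variables. With an empty params dict, `infer_value` asserts if the fill value depends on a graph input.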
---
 python/tvm/relay/frontend/pytorch.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 57997bb894..0dc9ffef6f 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -738,6 +738,8 @@ class PyTorchOpConverter:
             size = _op.concatenate(tmp, axis=0)
 
         if not isinstance(fill_value, _expr.Constant):
+            if isinstance(fill_value, _expr.Expr):
+                fill_value = _infer_value(fill_value, {})
             fill_value = _expr.const(fill_value, dtype=dtype)
         out = _op.full(fill_value, size, dtype=dtype)
         if need_reshape:
