huochaitiantang opened a new pull request #10721:
URL: https://github.com/apache/tvm/pull/10721


   
   When I run `reshape` on the Relay VM with dynamic input dimensions (`relay.Any()`) and a `newshape` in which a `0` comes after the `-1` (e.g. `[1, -1, 0]`), the output shape is wrong. The test code is as follows:
   ```python
   import numpy as np
   import tvm
   from tvm import relay

   target = "llvm"
   dtype = "float32"
   dev = tvm.device(target, 0)


   def build_module(x_shape, new_shape):
       # Build a single-op reshape module and compile it for the Relay VM.
       x = relay.var("x", shape=x_shape, dtype=dtype)
       y = relay.reshape(x, newshape=new_shape)
       func = relay.Function([x], y)
       mod = tvm.ir.IRModule.from_expr(func)
       print("mod:", mod)
       executable = relay.vm.compile(mod, target)
       return tvm.runtime.vm.VirtualMachine(executable, dev)


   def run(vm, x_np_shape):
       # Run the VM on random data and report the input and output shapes.
       x_np = np.random.uniform(size=x_np_shape).astype(dtype)
       z = vm.run(x_np)
       print("tvm input:", x_np.shape)
       print("tvm output:", z.numpy().shape)
       print("")


   def test(new_shape):
       # All input dims are relay.Any(), so the output shape is computed
       # at runtime by the reshape shape function.
       vm = build_module([relay.Any(), relay.Any(), relay.Any()], new_shape=new_shape)
       run(vm, (2, 3, 4))


   if __name__ == "__main__":
       test(new_shape=[0, 1, -1])  # correct
       test(new_shape=[0, -1, 1])  # correct
       test(new_shape=[1, -1, 0])  # wrong
       test(new_shape=[-1, 1, 0])  # wrong
   ```
   The output is:
   ```
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[0, 1, -1])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (2, 1, 12)
   
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[0, -1, 1])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (2, 12, 1)
   
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[1, -1, 0])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (1, 8, 3)
   
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[-1, 1, 0])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (8, 1, 3)
   ```
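The wrong shapes can be reproduced by hand. Write the input shape as $(d_0, d_1, d_2) = (2, 3, 4)$. If the `-1` does not advance the source index, the `0` at output position 2 copies $d_1$ instead of $d_2$, and the `-1` is then inferred so that the element count stays $2 \cdot 3 \cdot 4 = 24$:

$$
\begin{aligned}
[1, -1, 0]&:\quad \text{got } \left(1,\ \tfrac{24}{1 \cdot d_1},\ d_1\right) = (1, 8, 3), &\quad \text{expected } \left(1,\ \tfrac{24}{1 \cdot d_2},\ d_2\right) = (1, 6, 4)\\
[-1, 1, 0]&:\quad \text{got } \left(\tfrac{24}{1 \cdot d_1},\ 1,\ d_1\right) = (8, 1, 3), &\quad \text{expected } \left(\tfrac{24}{1 \cdot d_2},\ 1,\ d_2\right) = (6, 1, 4)
\end{aligned}
$$

In the two correct cases the `0` precedes the `-1`, so the stale source index never affects a copied dimension.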
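   The problem comes from https://github.com/apache/tvm/blob/main/python/tvm/relay/op/_transform.py#L420, where the `-1` case increments only `dst_idx` and not `src_idx`. As a result, a `0` that follows the `-1` copies the input dimension one position too early (`3` instead of `4` in the examples above). After fixing this, the output shape is correct: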
   ```
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[0, 1, -1])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (2, 1, 12)
   
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[0, -1, 1])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (2, 12, 1)
   
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[1, -1, 0])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (1, 6, 4)
   
   mod: def @main(%x: Tensor[(?, ?, ?), float32]) {
     reshape(%x, newshape=[-1, 1, 0])
   }
   
   tvm input: (2, 3, 4)
   tvm output: (6, 1, 4)
   ```
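The index bookkeeping can be modelled in plain Python. This is a minimal stand-alone sketch, not TVM's actual shape function: it assumes `newshape` contains only positive sizes, `0` and `-1` (Relay's other special values `-2`, `-3` and `-4` are left out), and the function name and the hypothetical flag `advance_src_on_infer` are illustrative. Passing `advance_src_on_infer=False` mimics the bug; the default `True` mimics the fix.

```python
from math import prod


def reshape_out_shape(data_shape, newshape, advance_src_on_infer=True):
    """Hypothetical model of reshape's 0 / -1 index bookkeeping.

    Only positive sizes, 0 (copy the input dim at src_idx) and -1 (infer)
    are handled. advance_src_on_infer=False reproduces the bug in this PR.
    """
    out = []
    src_idx = 0
    infer_idx = -1
    for dim in newshape:
        if dim > 0:
            # Explicit size: consumes one input dimension.
            out.append(dim)
            src_idx += 1
        elif dim == 0:
            # Copy the input dimension at the current source position.
            out.append(data_shape[src_idx])
            src_idx += 1
        elif dim == -1:
            if infer_idx >= 0:
                raise ValueError("only one dimension can be inferred")
            infer_idx = len(out)
            out.append(1)  # placeholder, filled in after the loop
            if advance_src_on_infer:
                src_idx += 1  # the fix: -1 also consumes one input dimension
        else:
            raise ValueError(f"unsupported newshape value: {dim}")
    if infer_idx >= 0:
        # The placeholder 1 does not change the product of known dims.
        out[infer_idx] = prod(data_shape) // prod(out)
    return tuple(out)


if __name__ == "__main__":
    for ns in ([0, 1, -1], [0, -1, 1], [1, -1, 0], [-1, 1, 0]):
        buggy = reshape_out_shape((2, 3, 4), ns, advance_src_on_infer=False)
        fixed = reshape_out_shape((2, 3, 4), ns)
        print(ns, "buggy:", buggy, "fixed:", fixed)
```

For the four `newshape` values in the test script, the `buggy` and `fixed` columns match the VM outputs before and after the fix.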
   
   Thanks for your review! @tqchen @junrushao1994 @icemelon 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
