This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0b13b5c844 [Unity] Enhance Torch-consistency in reshape (#16360)
0b13b5c844 is described below

commit 0b13b5c8445dfae5718af13faafef942d0a607fa
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jan 7 08:42:38 2024 -0800

    [Unity] Enhance Torch-consistency in reshape (#16360)
    
    This PR introduces the following signature changes:
    
    - `Tensor.reshape(shape)` to `Tensor.reshape(*shape)`
    - `Tensor.permute_dims(axes)` to `Tensor.permute_dims(*axes)`
---
 python/tvm/relax/frontend/nn/_tensor_op.py    |  4 ++--
 tests/python/relax/test_frontend_nn_tensor.py | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py b/python/tvm/relax/frontend/nn/_tensor_op.py
index 627b8b626c..3a646e29b8 100644
--- a/python/tvm/relax/frontend/nn/_tensor_op.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -78,10 +78,10 @@ class _TensorOp:
         other = _convert_scalar(other, self)
         return _op().minimum(self, other)
 
-    def reshape(self, shape):
+    def reshape(self, *shape):
         return _op().reshape(self, shape)
 
-    def permute_dims(self, axes):
+    def permute_dims(self, *axes):
         return _op().permute_dims(self, axes)
 
     def repeat(self, repeats: int, axis: Optional[int] = None):
diff --git a/tests/python/relax/test_frontend_nn_tensor.py b/tests/python/relax/test_frontend_nn_tensor.py
index 2d2a23cc46..bd96626c95 100644
--- a/tests/python/relax/test_frontend_nn_tensor.py
+++ b/tests/python/relax/test_frontend_nn_tensor.py
@@ -14,15 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax
-from tvm.relax.frontend.nn import Tensor, Module, spec
+from tvm.relax.frontend.nn import Module, Tensor, spec
 from tvm.script import relax as R
 
-import numpy as np
-
 
 def test_tensor_from_numpy():
     x = np.random.rand(1, 10)
@@ -136,8 +136,8 @@ def test_tensor_op_datatype():
 def test_tensor_op_manipulate():
     class Model(Module):
         def test(self, x: Tensor):
-            z0 = x.reshape([2, 5, 2])
-            z1 = x.permute_dims([2, 1, 0])
+            z0 = x.reshape(2, 5, 2)
+            z1 = x.permute_dims(2, 1, 0)
             z2 = x.repeat(2, axis=1)
             return (z0, z1, z2)
 

Reply via email to