This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0b13b5c844 [Unity] Enhance Torch-consistency in reshape (#16360)
0b13b5c844 is described below
commit 0b13b5c8445dfae5718af13faafef942d0a607fa
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jan 7 08:42:38 2024 -0800
[Unity] Enhance Torch-consistency in reshape (#16360)
This PR introduces the following signature changes:
- `Tensor.reshape(shape)` to `Tensor.reshape(*shape)`
- `Tensor.permute_dims(axes)` to `Tensor.permute_dims(*axes)`
---
python/tvm/relax/frontend/nn/_tensor_op.py | 4 ++--
tests/python/relax/test_frontend_nn_tensor.py | 10 +++++-----
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py b/python/tvm/relax/frontend/nn/_tensor_op.py
index 627b8b626c..3a646e29b8 100644
--- a/python/tvm/relax/frontend/nn/_tensor_op.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -78,10 +78,10 @@ class _TensorOp:
other = _convert_scalar(other, self)
return _op().minimum(self, other)
- def reshape(self, shape):
+ def reshape(self, *shape):
return _op().reshape(self, shape)
- def permute_dims(self, axes):
+ def permute_dims(self, *axes):
return _op().permute_dims(self, axes)
def repeat(self, repeats: int, axis: Optional[int] = None):
diff --git a/tests/python/relax/test_frontend_nn_tensor.py b/tests/python/relax/test_frontend_nn_tensor.py
index 2d2a23cc46..bd96626c95 100644
--- a/tests/python/relax/test_frontend_nn_tensor.py
+++ b/tests/python/relax/test_frontend_nn_tensor.py
@@ -14,15 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import numpy as np
import pytest
+
import tvm
import tvm.testing
from tvm import relax
-from tvm.relax.frontend.nn import Tensor, Module, spec
+from tvm.relax.frontend.nn import Module, Tensor, spec
from tvm.script import relax as R
-import numpy as np
-
def test_tensor_from_numpy():
x = np.random.rand(1, 10)
@@ -136,8 +136,8 @@ def test_tensor_op_datatype():
def test_tensor_op_manipulate():
class Model(Module):
def test(self, x: Tensor):
- z0 = x.reshape([2, 5, 2])
- z1 = x.permute_dims([2, 1, 0])
+ z0 = x.reshape(2, 5, 2)
+ z1 = x.permute_dims(2, 1, 0)
z2 = x.repeat(2, axis=1)
return (z0, z1, z2)