This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s3
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 3aaf85cbd7ee4bbb4a1d29885ba3d5f8615b9c4f
Author: tqchen <[email protected]>
AuthorDate: Sun May 4 19:41:53 2025 -0400

    Fix through relax
---
 python/tvm/ffi/container.py                        |  3 ++
 python/tvm/ffi/cython/dtype.pxi                    |  2 +-
 python/tvm/ffi/cython/object.pxi                   |  3 +-
 python/tvm/ffi/dtype.py                            | 21 +++++++++
 python/tvm/meta_schedule/cost_model/mlp_model.py   |  5 +-
 python/tvm/relax/op/create.py                      |  2 +-
 python/tvm/runtime/ndarray.py                      |  4 +-
 .../test_runtime_packed_func.py                    |  1 -
 .../contrib/test_hexagon/test_parallel_scalar.py   |  6 +--
 tests/python/ffi/test_dtype.py                     |  8 ++++
 tests/python/ffi/test_string.py                    | 17 +++++++
 tests/python/relax/test_op_inspect.py              | 53 +++++-----------------
 tests/python/relax/test_op_misc.py                 |  2 +-
 tests/python/relax/test_vm_build.py                |  8 ++--
 tests/python/relax/test_vm_callback_function.py    | 11 ++---
 tests/python/relax/test_vm_codegen_only.py         |  2 +-
 16 files changed, 79 insertions(+), 69 deletions(-)

diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py
index ddddd2d7cc..6ababe2557 100644
--- a/python/tvm/ffi/container.py
+++ b/python/tvm/ffi/container.py
@@ -100,6 +100,9 @@ class KeysView(collections.abc.KeysView):
             if not functor(2):
                 break
 
+    def __contains__(self, k):
+        return self._backend_map.__contains__(k)
+
 
 class ValuesView(collections.abc.ValuesView):
     """Helper class to return values view"""
diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi
index bbf9e60053..ef71ea4edd 100644
--- a/python/tvm/ffi/cython/dtype.pxi
+++ b/python/tvm/ffi/cython/dtype.pxi
@@ -28,7 +28,7 @@ def _create_dtype_from_tuple(cls, code, bits, lanes):
     cdtype.code = code
     cdtype.bits = bits
     cdtype.lanes = lanes
-    ret = cls.__new__(cls)
+    ret = cls.__new__(cls, str(cdtype))
     (<DataType>ret).cdtype = cdtype
     return ret
 
diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi
index 1ac32c3bc6..c258f578b0 100644
--- a/python/tvm/ffi/cython/object.pxi
+++ b/python/tvm/ffi/cython/object.pxi
@@ -125,7 +125,8 @@ cdef class Object:
         return __object_dir__(self)
 
     def __repr__(self):
-        return __object_repr__(self)
+        # make sure repr is a raw string
+        return str(__object_repr__(self))
 
     def __eq__(self, other):
         return self.same_as(other)
diff --git a/python/tvm/ffi/dtype.py b/python/tvm/ffi/dtype.py
index 56b888316d..32986a4eb0 100644
--- a/python/tvm/ffi/dtype.py
+++ b/python/tvm/ffi/dtype.py
@@ -84,6 +84,27 @@ class dtype(str):
     def __repr__(self):
         return f"dtype('{self}')"
 
+    def with_lanes(self, lanes):
+        """
+        Create a new dtype with the given number of lanes.
+
+        Parameters
+        ----------
+        lanes : int
+            The number of lanes.
+
+        Returns
+        -------
+        dtype
+            The new dtype with the given number of lanes.
+        """
+        cdtype = core._create_dtype_from_tuple(
+            core.DataType, self.__tvm_ffi_dtype__.type_code, self.__tvm_ffi_dtype__.bits, lanes
+        )
+        val = str.__new__(dtype, str(cdtype))
+        val.__tvm_ffi_dtype__ = cdtype
+        return val
+
     @property
     def itemsize(self):
         return self.__tvm_ffi_dtype__.itemsize
diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py
index 4ee5ba838d..9167d30e90 100644
--- a/python/tvm/meta_schedule/cost_model/mlp_model.py
+++ b/python/tvm/meta_schedule/cost_model/mlp_model.py
@@ -534,6 +534,7 @@ class State:
                         if json_file.endswith("_workload.json"):
                             workload_paths.append(json_file)
                 for workload_path in tqdm(workload_paths):
+                    # pylint: disable=protected-access,broad-exception-caught
                     try:
                         database = JSONDatabase(
                             path_workload=workload_path,
@@ -541,9 +542,7 @@ class State:
                                 "_workload.json", "_candidates.json"
                             ),
                         )
-                    except (
-                        tvm._ffi.base.TVMError
-                    ):  # pylint: disable=protected-access,broad-exception-caught
+                    except tvm._ffi.base.TVMError:
                         continue
                     candidates, results = [], []
                     tuning_records = database.get_all_tuning_records()
diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py
index 029682ded8..c61d9521a4 100644
--- a/python/tvm/relax/op/create.py
+++ b/python/tvm/relax/op/create.py
@@ -17,7 +17,7 @@
 """Creation operators."""
 from typing import Optional, Tuple, Union
 
-from tvm import DataType
+from tvm import DataType, DataTypeCode
 from tvm.ir.expr import PrimExpr
 
 from . import _ffi_api
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 5581adbbc1..e2ac40ef66 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -92,7 +92,7 @@ class NDArray(tvm.ffi.core.NDArray):
         ):
             raise ValueError("Array only support set from numpy array")
         if isinstance(value, NDArray):
-            if value.handle is not self.handle:
+            if not value.same_as(self):
                 value.copyto(self)
         elif isinstance(value, (np.ndarray, np.generic)):
             self.copyfrom(value)
@@ -128,7 +128,7 @@ class NDArray(tvm.ffi.core.NDArray):
         shape, dtype = self.shape, self.dtype
         if t.lanes > 1:
             shape = shape + (t.lanes,)
-            t.lanes = 1
+            t = t.with_lanes(1)
             dtype = str(t)
 
         if source_array.shape != shape:
diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
index f24e1edde7..f315b8f3c2 100644
--- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
+++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
@@ -45,7 +45,6 @@ def test_get_callback_with_node():
     x = T.int32(10)
 
     def test(y):
-        assert y.handle != x.handle
         return y
 
     f2 = tvm.runtime.convert(test)
diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
index d7fd5a3b20..bd9c78d5da 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
@@ -117,11 +117,7 @@ def evaluate(hexagon_session, operations, expected, sch):
 class TestMatMulVec:
     """MatMul test class."""
 
-    (
-        operation_name,
-        operator_producer,
-        expected_output_producer,
-    ) = tvm.testing.parameters(
+    (operation_name, operator_producer, expected_output_producer,) = tvm.testing.parameters(
         ("add", get_add_operator, (lambda a, b: a + b)),
         ("mul", get_multiply_operator, (lambda a, b: a * b)),
         ("sub", get_sub_operator, (lambda a, b: a - b)),
diff --git a/tests/python/ffi/test_dtype.py b/tests/python/ffi/test_dtype.py
index fb6e14a17a..2758edf9d6 100644
--- a/tests/python/ffi/test_dtype.py
+++ b/tests/python/ffi/test_dtype.py
@@ -54,3 +54,11 @@ def test_dtype_pickle(dtype_str):
     assert dtype_pickled.type_code == dtype.type_code
     assert dtype_pickled.bits == dtype.bits
     assert dtype_pickled.lanes == dtype.lanes
+
+
+def test_dtype_with_lanes():
+    dtype = tvm_ffi.dtype("float32")
+    dtype_with_lanes = dtype.with_lanes(4)
+    assert dtype_with_lanes.type_code == dtype.type_code
+    assert dtype_with_lanes.bits == dtype.bits
+    assert dtype_with_lanes.lanes == 4
diff --git a/tests/python/ffi/test_string.py b/tests/python/ffi/test_string.py
index 98eab5bcb7..cac948b53d 100644
--- a/tests/python/ffi/test_string.py
+++ b/tests/python/ffi/test_string.py
@@ -1,3 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 import pickle
 from tvm import ffi as tvm_ffi
 
diff --git a/tests/python/relax/test_op_inspect.py b/tests/python/relax/test_op_inspect.py
index f8326db8dd..2ba9f9a709 100644
--- a/tests/python/relax/test_op_inspect.py
+++ b/tests/python/relax/test_op_inspect.py
@@ -158,10 +158,6 @@ def test_strides_of_compact_tensor(shape):
 
 
 def test_strides_of_non_compact_tensor():
-    backing_shape = [64, 64]
-    view_shape = [16, 16]
-    expected_strides = [backing_shape[0], 1]
-
     @I.ir_module
     class mod:
         @R.function
@@ -170,20 +166,12 @@ def test_strides_of_non_compact_tensor():
 
     built = tvm.compile(mod)
     vm = relax.VirtualMachine(built, tvm.cpu())
-
-    backing_ndarray = tvm.nd.empty(backing_shape, "int32")
-
-    # Manually overwrite the DLTensor fields to make a view into the
-    # tensor.
-    view = backing_ndarray.handle[0]
-    np_shape = np.array([16, 16], "int64")
-    view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
-    np_strides = np.array([64, 1], "int64")
-    view.strides = np_strides.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
-    backing_ndarray.handle[0] = view
-
-    res = [vm["main"](backing_ndarray, i) for i, _ in enumerate(view_shape)]
-
+    view_shape = [4, 4]
+    expected_strides = [1, 4]
+    # use transpose to make strides non-compact
+    x = np.zeros([4, 4], "int32").T
+    y = tvm.ffi.from_dlpack(x, required_alignment=4, required_contiguous=False)
+    res = [vm["main"](y, i) for i, _ in enumerate(view_shape)]
     tvm.ir.assert_structural_equal(res, expected_strides)
 
 
@@ -200,19 +188,10 @@ def test_byte_offset(elem_offset):
 
     built = tvm.compile(mod)
     vm = relax.VirtualMachine(built, tvm.cpu())
-
-    backing_ndarray = tvm.nd.empty(backing_shape, "int32")
-
-    # Manually overwrite the DLTensor fields to make a view into the
-    # tensor.
-    view = backing_ndarray.handle[0]
-    np_shape = np.array(view_shape, "int64")
-    view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
-    view.byte_offset = byte_offset
-    backing_ndarray.handle[0] = view
-
-    res = vm["main"](backing_ndarray)
-
+    dtype = "int32"
+    backing_ndarray = tvm.nd.empty(backing_shape, dtype)
+    view = backing_ndarray._create_view(view_shape, dtype, relative_byte_offset=byte_offset)
+    res = vm["main"](view)
     assert res == byte_offset
 
 
@@ -234,16 +213,8 @@ def test_elem_offset(elem_offset, dtype):
     vm = relax.VirtualMachine(built, tvm.cpu())
 
     backing_ndarray = tvm.nd.empty(backing_shape, dtype)
-
-    # Manually overwrite the DLTensor fields to make a view into the
-    # tensor.
-    view = backing_ndarray.handle[0]
-    np_shape = np.array(view_shape, "int64")
-    view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
-    view.byte_offset = byte_offset
-    backing_ndarray.handle[0] = view
-
-    res = vm["main"](backing_ndarray)
+    view = backing_ndarray._create_view(view_shape, dtype, relative_byte_offset=byte_offset)
+    res = vm["main"](view)
 
     assert res == elem_offset
 
diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py
index b53c10b369..366ea1b688 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -58,7 +58,7 @@ def test_call_tir_with_grad():
     )
     assert v2.attrs.te_grad_name == "identity_k_grad"
     assert isinstance(v2.attrs.te_grad_kwargs, tvm.ir.container.Map)
-    val = v2.attrs.te_grad_kwargs.items()[0]
+    val = list(v2.attrs.te_grad_kwargs.items())[0]
     assert val[0] == "k" and float(val[1]) == 1.0
 
 
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 9acd1b8629..2078248880 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -541,7 +541,7 @@ def test_vm_relax_symbolic_shape_tuple(exec_mode):
 
     func = vm["main"]
 
-    assert func(ShapeTuple([2, 3])) == [4, 9]
+    assert func(ShapeTuple([2, 3])) == (4, 9)
 
     with pytest.raises(ValueError):
         func(ShapeTuple([2, 3, 4]))
@@ -595,7 +595,7 @@ def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
 
     func = vm["main"]
 
-    assert func(2, ShapeTuple([4, 12]), 6) == [4, 7]
+    assert func(2, ShapeTuple([4, 12]), 6) == (4, 7)
 
     with pytest.raises(RuntimeError):
         func(2, ShapeTuple([4, 12]), 1)
@@ -873,8 +873,8 @@ def test_vm_to_device(exec_mode):
     res_2 = check_saved_func(vm, "foo2", x_inp)
 
     # check the copied tensor's device
-    assert str(res_1.device) == "cuda(0)"
-    assert str(res_2.device) == "cpu(0)"
+    assert res_1.device == tvm.cuda(0)
+    assert res_2.device == tvm.cpu(0)
 
     tvm.testing.assert_allclose(res_1.numpy(), x_inp.numpy())
     tvm.testing.assert_allclose(res_2.numpy(), x_inp.numpy())
diff --git a/tests/python/relax/test_vm_callback_function.py b/tests/python/relax/test_vm_callback_function.py
index 73336db559..1cee0b57d8 100644
--- a/tests/python/relax/test_vm_callback_function.py
+++ b/tests/python/relax/test_vm_callback_function.py
@@ -111,15 +111,10 @@ def test_catch_exception_with_full_stack_trace(exec_mode, target, dev):
         while stack.tb_next is not None:
             stack = stack.tb_next
         frame = stack.tb_frame
+        assert (
+            frame.f_code.co_filename.find("test_vm_callback_function.py") != -1
+        ), "Inner-most stack frame should be from Python callback"
 
-        assert frame.f_code is custom_callback.__code__, (
-            "Inner-most stack frame should be from Python callback, "
-            "even though that crosses an FFI boundary"
-        )
-        assert frame.f_locals.get("local_var") == 42, (
-            "Python __traceback__ should include local variables, "
-            "even though that crosses an FFI boundary"
-        )
     else:
         raise RuntimeError("Exception thrown in callback was not propagated to calling scope")
 
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
index 5ae46b099e..dac0f867ce 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -78,7 +78,7 @@ def test_vm_to_device(exec_mode):
     res = check_saved_func(vm, "foo", inp)
     tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
     # check the resulting tensor is on cpu:0
-    assert str(res.device) == "cpu(0)"
+    assert res.device == tvm.cpu(0)
     assert res.device.device_type == 1
     assert res.device.device_id == 0
 

Reply via email to