This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 520ffd650eba7c1ab6d295e4b9e9c0aa4a14a5d2
Author: tqchen <[email protected]>
AuthorDate: Tue Apr 22 12:39:39 2025 -0400

    fix structural hash and relax tir codegen
---
 include/tvm/runtime/container/shape_tuple.h      |  4 +++-
 src/contrib/msc/core/printer/msc_base_printer.cc |  3 +--
 src/node/structural_hash.cc                      |  5 ++---
 src/tir/transforms/arg_binder.cc                 |  2 +-
 src/tir/transforms/make_packed_api.cc            | 11 ++++++-----
 tests/python/relax/test_vm_codegen_tir.py        |  2 +-
 6 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h
index 6a0497049f..61f44d30be 100644
--- a/include/tvm/runtime/container/shape_tuple.h
+++ b/include/tvm/runtime/container/shape_tuple.h
@@ -24,10 +24,12 @@
 #ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
 #define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
 
+#include <tvm/ffi/container/shape.h>
+
 #include <ostream>
 #include <utility>
 #include <vector>
-#include <tvm/ffi/container/shape.h>
+
 #include "./base.h"
 
 namespace tvm {
diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc
index dac732aaa6..838d284d13 100644
--- a/src/contrib/msc/core/printer/msc_base_printer.cc
+++ b/src/contrib/msc/core/printer/msc_base_printer.cc
@@ -114,8 +114,7 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) {
       output_ << float_imm->value;
     }
   } else if (const auto* string_obj = value.as<StringObj>()) {
-    output_ << "\"" << tvm::support::StrEscape(string_obj->data, string_obj->size)
-            << "\"";
+    output_ << "\"" << tvm::support::StrEscape(string_obj->data, string_obj->size) << "\"";
   } else {
     LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey();
   }
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 4c864d6613..50b8cb4e9b 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -316,8 +316,7 @@ struct StringObjTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
 
   static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) {
-    hash_reduce->SHashReduceHashedValue(
-        ffi::details::StableHashBytes(key->data, key->size));
+    hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size));
   }
 
   static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs,
@@ -497,7 +496,7 @@ struct ShapeTupleObjTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
 
   static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) {
-    hash_reduce(self->size);
+    hash_reduce(static_cast<uint64_t>(self->size));
     for (uint32_t i = 0; i < self->size; ++i) {
       hash_reduce(self->data[i]);
     }
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 9270a14df9..5b9e005b7e 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -151,7 +151,7 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kin
 
 void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
                              const PrimExpr& device_id, const Var& handle,
-                            const std::string& arg_name) {
+                             const std::string& arg_name) {
   const DataType tvm_shape_type = DataType::ShapeIndex();
   const DataType tvm_ndim_type = DataType::Int(32);
   const Stmt nop = Evaluate(0);
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index d1931ebced..d241d43c19 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -293,9 +293,11 @@ PrimFunc MakePackedAPI(PrimFunc func) {
       // if type_index is NDArray, we need to add the offset of the DLTensor header
       // which always equals 16 bytes, this ensures that T.handle always shows up as a DLTensor*
       arg_value = f_load_arg_value(param.dtype(), i);
-      PrimExpr handle_from_ndarray = Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
-                                      {arg_value, IntImm(DataType::Int(32), 16)});
-      arg_value = Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value);
+      PrimExpr handle_from_ndarray =
+          Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
+               {arg_value, IntImm(DataType::Int(32), 16)});
+      arg_value =
+          Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value);
     } else if (dtype.is_bool()) {
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be boolean";
@@ -348,8 +350,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   }
 
   for (const auto& [var, buffer] : buffer_def) {
-    binder.BindDLTensor(buffer, device_type, device_id, var,
-                        name_hint + "." + var->name_hint);
+    binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint);
     arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
   }
 
diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py
index 60f096585d..41f8e81735 100644
--- a/tests/python/relax/test_vm_codegen_tir.py
+++ b/tests/python/relax/test_vm_codegen_tir.py
@@ -89,7 +89,7 @@ def test_tir_call():
         def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
             T.func_attr({"global_symbol": "__vmtir__foo"})
             T.call_cpacked(
-                "shape_func", T.anylist_getitem(r, T.int32(0)), T.reinterpret("handle", T.uint64(0))
+                "shape_func", T.anylist_getitem(r, T.int32(0))
             )
             T.anylist_setitem_call_packed(
                 r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(0))

Reply via email to