This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 59516c1949 [REFACTOR][IR] Clean up PrimType follow-ups (#19884)
59516c1949 is described below

commit 59516c1949a0032441f04bfa3146267fea27e26d
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 25 06:49:02 2026 -0400

    [REFACTOR][IR] Clean up PrimType follow-ups (#19884)
---
 include/tvm/ir/base_expr.h                                    | 11 ++++++++++-
 .../backend/cuda/operator/tile_primitive/elementwise/smem.py  |  5 ++---
 python/tvm/backend/metal/script.py                            |  6 +++---
 python/tvm/backend/trn/transform/naive_allocator.py           |  3 +--
 src/ir/type.cc                                                | 10 ----------
 tests/python/tirx/test_op_namespace_cleanup.py                |  8 ++++++++
 6 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/include/tvm/ir/base_expr.h b/include/tvm/ir/base_expr.h
index 0a844bb3ba..f1176955c7 100644
--- a/include/tvm/ir/base_expr.h
+++ b/include/tvm/ir/base_expr.h
@@ -214,7 +214,16 @@ class PrimType final : public Type {
    * This uses the same packed sub-byte dtype sizing rule as runtime tensors.
    * Scalable vector types have no compile-time storage size and are rejected.
    */
-  TVM_DLL size_t StorageBytes() const;
+  TVM_FFI_INLINE size_t StorageBytes() const {
+    DLDataType dtype = get()->dtype;
+    int16_t encoded_lanes = static_cast<int16_t>(dtype.lanes);
+    if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) {
+      TVM_FFI_THROW(InternalError)
+          << "Cannot compute compile-time storage bytes for non-fixed vector type " << dtype;
+    }
+    return static_cast<size_t>(
+        (static_cast<uint64_t>(dtype.bits) * static_cast<uint64_t>(dtype.lanes) + 7) / 8);
+  }
 
   /*! \brief Return the same type with a different dtype code, preserving bits and lanes. */
   TVM_FFI_INLINE PrimType WithCode(DLDataTypeCode code) const {
diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py
index 8c0b97f24e..d85193c06a 100644
--- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py
+++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py
@@ -30,7 +30,6 @@ codegen time. Packed-vec emit requires the innermost dim to have stride 1
 
 from __future__ import annotations
 
-from tvm.runtime import DataType
 from tvm.script import tirx as T
 from tvm.tirx import PrimFunc, TilePrimitiveCall
 from tvm.tirx.operator.tile_primitive import DispatchContext
@@ -100,10 +99,10 @@ def is_smem_ewise(spec):
 def _max_layout_vec(plan, total: int, thread_cnt: int) -> int:
     """Widest vec_chunk dividing all operands' innermost extents AND
     ``total / thread_cnt``, within dtype-bit candidates ``{128,64,32,16,8}``."""
-    max_bits = DataType(plan.dst.buffer.dtype.dtype).bits
+    max_bits = plan.dst.buffer.dtype.dtype.bits
     for s in plan.srcs:
         if s.buf_region is not None:
-            max_bits = max(max_bits, DataType(s.buf_region.buffer.dtype.dtype).bits)
+            max_bits = max(max_bits, s.buf_region.buffer.dtype.dtype.bits)
     per_thread = total // thread_cnt if thread_cnt > 0 else total
     if total % thread_cnt != 0:
         return 1
diff --git a/python/tvm/backend/metal/script.py b/python/tvm/backend/metal/script.py
index 7c5d45564a..8f3fe337b4 100644
--- a/python/tvm/backend/metal/script.py
+++ b/python/tvm/backend/metal/script.py
@@ -37,19 +37,19 @@ class MetalNamespace:
     def simd_shuffle(var, lane):
         if isinstance(var, Buffer):
             var = var[0]
-        return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle", var, lane)
+        return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle", var, lane)
 
     @staticmethod
     def simd_shuffle_up(var, delta):
         if isinstance(var, Buffer):
             var = var[0]
-        return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_up", var, delta)
+        return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle_up", var, delta)
 
     @staticmethod
     def simd_shuffle_down(var, delta):
         if isinstance(var, Buffer):
             var = var[0]
-        return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_down", var, delta)
+        return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle_down", var, delta)
 
 
 __all__ = ["MetalNamespace"]
diff --git a/python/tvm/backend/trn/transform/naive_allocator.py b/python/tvm/backend/trn/transform/naive_allocator.py
index 96fbe523ff..72f6284886 100644
--- a/python/tvm/backend/trn/transform/naive_allocator.py
+++ b/python/tvm/backend/trn/transform/naive_allocator.py
@@ -17,7 +17,6 @@
 
 import functools
 
-from tvm import DataType
 from tvm.tirx import AllocBuffer, IntImm
 from tvm.tirx.buffer import Buffer
 from tvm.tirx.stmt_functor import StmtVisitor
@@ -48,7 +47,7 @@ def get_buffer_size(buffer: Buffer) -> int:
         raise ValueError(
             f"Buffer {buffer.name} has non-constant shape. Do not know how to allocate it."
         )
-    return int(num_elem * DataType(buffer.dtype.dtype).itemsize)
+    return int(num_elem * buffer.dtype.dtype.itemsize)
 
 
 class AllocInfoCollector(StmtVisitor):
diff --git a/src/ir/type.cc b/src/ir/type.cc
index 3a889c49c6..0652a38792 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -21,7 +21,6 @@
  * \file src/ir/type.cc
  * \brief Common type system AST nodes throughout the IR.
  */
-#include <tvm/ffi/container/tensor.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/ir/type.h>
@@ -136,15 +135,6 @@ PrimType PrimType::ScalableVector(DLDataTypeCode code, int bits, int lanes) {
   return PrimType(ScalableVectorDType(code, bits, lanes));
 }
 
-size_t PrimType::StorageBytes() const {
-  int16_t encoded_lanes = static_cast<int16_t>(get()->dtype.lanes);
-  if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) {
-    TVM_FFI_THROW(InternalError)
-        << "Cannot compute compile-time storage bytes for non-fixed vector type " << get()->dtype;
-  }
-  return ffi::GetDataSize(1, get()->dtype);
-}
-
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("ir.PrimType", [](DLDataType dtype) { return PrimType(dtype); });
diff --git a/tests/python/tirx/test_op_namespace_cleanup.py b/tests/python/tirx/test_op_namespace_cleanup.py
index 29cf3ba972..e6d71cabd4 100644
--- a/tests/python/tirx/test_op_namespace_cleanup.py
+++ b/tests/python/tirx/test_op_namespace_cleanup.py
@@ -295,17 +295,23 @@ def test_device_intrinsic_printer_roundtrips_canonical_namespaces():
         T.cuda.copy_bytes(dst, src, 16)
         T.ptx.ldg32(R[0], 1, A[0], 0)
         T.metal.simd_shuffle(A[0], 0)
+        T.metal.simd_shuffle_up(A[0], 1)
+        T.metal.simd_shuffle_down(A[0], 1)
 
     calls = _expr_calls(device_namespaces)
     assert [call.op.name for call in calls] == [
         "tirx.cuda.copy_bytes",
         "tirx.ptx.ldg32",
         "tirx.metal.simd_shuffle",
+        "tirx.metal.simd_shuffle_up",
+        "tirx.metal.simd_shuffle_down",
     ]
     for op_name, namespace in [
         ("tirx.cuda.copy_bytes", "cuda"),
         ("tirx.ptx.ldg32", "ptx"),
         ("tirx.metal.simd_shuffle", "metal"),
+        ("tirx.metal.simd_shuffle_up", "metal"),
+        ("tirx.metal.simd_shuffle_down", "metal"),
     ]:
         assert _op_attr(op_name, "TIRxOpCategory") == "device_intrin"
         assert _op_attr(op_name, "TDeviceIntrinsicNamespace") == namespace
@@ -315,6 +321,8 @@ def test_device_intrinsic_printer_roundtrips_canonical_namespaces():
     assert "T.cuda.copy_bytes(" in code
     assert "T.ptx.ldg32(" in code
     assert "T.metal.simd_shuffle(" in code
+    assert "T.metal.simd_shuffle_up(" in code
+    assert "T.metal.simd_shuffle_down(" in code
     assert "T.tirx." not in code
     reparsed = tvm.script.from_source(code)
     assert reparsed.script() == code

Reply via email to