This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 59516c1949 [REFACTOR][IR] Clean up PrimType follow-ups (#19884)
59516c1949 is described below
commit 59516c1949a0032441f04bfa3146267fea27e26d
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 25 06:49:02 2026 -0400
[REFACTOR][IR] Clean up PrimType follow-ups (#19884)
---
include/tvm/ir/base_expr.h | 11 ++++++++++-
.../backend/cuda/operator/tile_primitive/elementwise/smem.py | 5 ++---
python/tvm/backend/metal/script.py | 6 +++---
python/tvm/backend/trn/transform/naive_allocator.py | 3 +--
src/ir/type.cc | 10 ----------
tests/python/tirx/test_op_namespace_cleanup.py | 8 ++++++++
6 files changed, 24 insertions(+), 19 deletions(-)
diff --git a/include/tvm/ir/base_expr.h b/include/tvm/ir/base_expr.h
index 0a844bb3ba..f1176955c7 100644
--- a/include/tvm/ir/base_expr.h
+++ b/include/tvm/ir/base_expr.h
@@ -214,7 +214,16 @@ class PrimType final : public Type {
* This uses the same packed sub-byte dtype sizing rule as runtime tensors.
* Scalable vector types have no compile-time storage size and are rejected.
*/
- TVM_DLL size_t StorageBytes() const;
+ TVM_FFI_INLINE size_t StorageBytes() const {
+ DLDataType dtype = get()->dtype;
+ int16_t encoded_lanes = static_cast<int16_t>(dtype.lanes);
+ if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) {
+ TVM_FFI_THROW(InternalError)
+ << "Cannot compute compile-time storage bytes for non-fixed vector type " << dtype;
+ }
+ return static_cast<size_t>(
+ (static_cast<uint64_t>(dtype.bits) * static_cast<uint64_t>(dtype.lanes) + 7) / 8);
+ }
/*! \brief Return the same type with a different dtype code, preserving bits and lanes. */
TVM_FFI_INLINE PrimType WithCode(DLDataTypeCode code) const {
diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py
index 8c0b97f24e..d85193c06a 100644
--- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py
+++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py
@@ -30,7 +30,6 @@ codegen time. Packed-vec emit requires the innermost dim to have stride 1
from __future__ import annotations
-from tvm.runtime import DataType
from tvm.script import tirx as T
from tvm.tirx import PrimFunc, TilePrimitiveCall
from tvm.tirx.operator.tile_primitive import DispatchContext
@@ -100,10 +99,10 @@ def is_smem_ewise(spec):
def _max_layout_vec(plan, total: int, thread_cnt: int) -> int:
"""Widest vec_chunk dividing all operands' innermost extents AND
``total / thread_cnt``, within dtype-bit candidates ``{128,64,32,16,8}``."""
- max_bits = DataType(plan.dst.buffer.dtype.dtype).bits
+ max_bits = plan.dst.buffer.dtype.dtype.bits
for s in plan.srcs:
if s.buf_region is not None:
- max_bits = max(max_bits, DataType(s.buf_region.buffer.dtype.dtype).bits)
+ max_bits = max(max_bits, s.buf_region.buffer.dtype.dtype.bits)
per_thread = total // thread_cnt if thread_cnt > 0 else total
if total % thread_cnt != 0:
return 1
diff --git a/python/tvm/backend/metal/script.py b/python/tvm/backend/metal/script.py
index 7c5d45564a..8f3fe337b4 100644
--- a/python/tvm/backend/metal/script.py
+++ b/python/tvm/backend/metal/script.py
@@ -37,19 +37,19 @@ class MetalNamespace:
def simd_shuffle(var, lane):
if isinstance(var, Buffer):
var = var[0]
- return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle", var, lane)
+ return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle", var, lane)
@staticmethod
def simd_shuffle_up(var, delta):
if isinstance(var, Buffer):
var = var[0]
- return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_up", var, delta)
+ return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle_up", var, delta)
@staticmethod
def simd_shuffle_down(var, delta):
if isinstance(var, Buffer):
var = var[0]
- return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_down", var, delta)
+ return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle_down", var, delta)
__all__ = ["MetalNamespace"]
diff --git a/python/tvm/backend/trn/transform/naive_allocator.py b/python/tvm/backend/trn/transform/naive_allocator.py
index 96fbe523ff..72f6284886 100644
--- a/python/tvm/backend/trn/transform/naive_allocator.py
+++ b/python/tvm/backend/trn/transform/naive_allocator.py
@@ -17,7 +17,6 @@
import functools
-from tvm import DataType
from tvm.tirx import AllocBuffer, IntImm
from tvm.tirx.buffer import Buffer
from tvm.tirx.stmt_functor import StmtVisitor
@@ -48,7 +47,7 @@ def get_buffer_size(buffer: Buffer) -> int:
raise ValueError(
f"Buffer {buffer.name} has non-constant shape. Do not know how to allocate it."
)
- return int(num_elem * DataType(buffer.dtype.dtype).itemsize)
+ return int(num_elem * buffer.dtype.dtype.itemsize)
class AllocInfoCollector(StmtVisitor):
diff --git a/src/ir/type.cc b/src/ir/type.cc
index 3a889c49c6..0652a38792 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -21,7 +21,6 @@
* \file src/ir/type.cc
* \brief Common type system AST nodes throughout the IR.
*/
-#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/type.h>
@@ -136,15 +135,6 @@ PrimType PrimType::ScalableVector(DLDataTypeCode code, int bits, int lanes) {
return PrimType(ScalableVectorDType(code, bits, lanes));
}
-size_t PrimType::StorageBytes() const {
- int16_t encoded_lanes = static_cast<int16_t>(get()->dtype.lanes);
- if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) {
- TVM_FFI_THROW(InternalError)
- << "Cannot compute compile-time storage bytes for non-fixed vector type " << get()->dtype;
- }
- return ffi::GetDataSize(1, get()->dtype);
-}
-
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("ir.PrimType", [](DLDataType dtype) { return PrimType(dtype); });
diff --git a/tests/python/tirx/test_op_namespace_cleanup.py b/tests/python/tirx/test_op_namespace_cleanup.py
index 29cf3ba972..e6d71cabd4 100644
--- a/tests/python/tirx/test_op_namespace_cleanup.py
+++ b/tests/python/tirx/test_op_namespace_cleanup.py
@@ -295,17 +295,23 @@ def test_device_intrinsic_printer_roundtrips_canonical_namespaces():
T.cuda.copy_bytes(dst, src, 16)
T.ptx.ldg32(R[0], 1, A[0], 0)
T.metal.simd_shuffle(A[0], 0)
+ T.metal.simd_shuffle_up(A[0], 1)
+ T.metal.simd_shuffle_down(A[0], 1)
calls = _expr_calls(device_namespaces)
assert [call.op.name for call in calls] == [
"tirx.cuda.copy_bytes",
"tirx.ptx.ldg32",
"tirx.metal.simd_shuffle",
+ "tirx.metal.simd_shuffle_up",
+ "tirx.metal.simd_shuffle_down",
]
for op_name, namespace in [
("tirx.cuda.copy_bytes", "cuda"),
("tirx.ptx.ldg32", "ptx"),
("tirx.metal.simd_shuffle", "metal"),
+ ("tirx.metal.simd_shuffle_up", "metal"),
+ ("tirx.metal.simd_shuffle_down", "metal"),
]:
assert _op_attr(op_name, "TIRxOpCategory") == "device_intrin"
assert _op_attr(op_name, "TDeviceIntrinsicNamespace") == namespace
@@ -315,6 +321,8 @@ def test_device_intrinsic_printer_roundtrips_canonical_namespaces():
assert "T.cuda.copy_bytes(" in code
assert "T.ptx.ldg32(" in code
assert "T.metal.simd_shuffle(" in code
+ assert "T.metal.simd_shuffle_up(" in code
+ assert "T.metal.simd_shuffle_down(" in code
assert "T.tirx." not in code
reparsed = tvm.script.from_source(code)
assert reparsed.script() == code