This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0082836d2d [REFACTOR][RELAX] Phase out Relax PrimType (#19858)
0082836d2d is described below

commit 0082836d2d2bf55892d10644dfd2913274e95b88
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Jun 22 11:27:20 2026 -0400

    [REFACTOR][RELAX] Phase out Relax PrimType (#19858)
    
    Summary:
    - Remove the Relax-specific PrimType node/API and use the canonical
    tvm.ir.PrimType for dtype-only scalar types.
    - Update the parser, printer, analysis, op inference/legalization, and
    tests so they no longer rely on value-bearing PrimType semantics.
    - Preserve scalar values where needed by reading PrimValue expressions
    directly instead of storing values in the type.
---
 include/tvm/relax/type.h                           | 35 -----------
 python/tvm/relax/__init__.py                       |  1 -
 python/tvm/relax/backend/metal/coreml.py           |  3 +-
 python/tvm/relax/expr.py                           |  4 +-
 python/tvm/relax/script/parser/entry.py            | 35 +++++------
 python/tvm/relax/testing/ast_printer.py            |  2 +-
 .../tvm/relax/transform/lazy_transform_params.py   |  2 +-
 python/tvm/relax/transform/legalize_ops/index.py   | 15 +++--
 python/tvm/relax/type.py                           | 73 +---------------------
 python/tvm/relax/utils.py                          | 31 ++++-----
 src/relax/analysis/type_analysis.cc                | 56 ++---------------
 src/relax/backend/vm/vm_shape_lower.cc             |  7 ---
 src/relax/ir/block_builder.cc                      |  9 ---
 src/relax/ir/dataflow_expr_rewriter.cc             |  2 +-
 src/relax/ir/dependent_type.cc                     | 27 --------
 src/relax/ir/expr.cc                               |  2 +-
 src/relax/ir/type_functor.cc                       | 19 +-----
 src/relax/op/memory/view.cc                        |  4 +-
 src/relax/op/op.cc                                 |  2 +-
 src/relax/op/tensor/index.cc                       | 53 +++++++++-------
 src/relax/op/tensor/inspect.cc                     | 50 +++++++--------
 src/relax/script/printer/dependent_type.cc         | 17 -----
 src/relax/transform/fuse_tir.cc                    | 18 +++---
 src/relax/transform/remove_unused_parameters.cc    |  2 +-
 src/relax/utils.cc                                 | 13 ----
 src/tirx/ir/function.cc                            |  4 +-
 tests/cpp/nested_msg_test.cc                       | 10 +--
 tests/python/relax/test_analysis_type_analysis.py  | 23 +++----
 tests/python/relax/test_ast_printer.py             |  2 +-
 .../relax/test_backend_transform_shape_lower.py    |  3 +
 tests/python/relax/test_bind_symbolic_vars.py      |  1 +
 tests/python/relax/test_blockbuilder_core.py       |  4 +-
 tests/python/relax/test_blockbuilder_emit_te.py    | 14 ++---
 tests/python/relax/test_dataflow_rewriter.py       |  1 +
 tests/python/relax/test_expr.py                    |  8 +--
 tests/python/relax/test_op_binary.py               | 14 ++---
 tests/python/relax/test_op_manipulate.py           |  4 +-
 .../relax/test_transform_compute_prim_value.py     |  3 +
 .../relax/test_transform_lazy_transform_params.py  |  3 +
 .../test_transform_remove_unused_parameters.py     |  3 +
 .../test_transform_rewrite_dataflow_reshape.py     |  4 +-
 tests/python/relax/test_tvmscript_parser.py        | 64 +------------------
 tests/python/relax/test_tvmscript_printer_relax.py | 15 +++--
 tests/python/relax/test_type.py                    | 26 ++------
 tests/python/relax/test_utils.py                   |  6 +-
 tests/python/relax/test_vm_build.py                |  2 +
 tests/python/tirx-base/test_tir_specialize.py      |  4 +-
 .../python/tvmscript/test_tvmscript_parser_tir.py  | 10 +--
 48 files changed, 206 insertions(+), 504 deletions(-)

diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h
index fcaddc3ab5..9174e66bdd 100644
--- a/include/tvm/relax/type.h
+++ b/include/tvm/relax/type.h
@@ -114,41 +114,6 @@ class ObjectType : public Type {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectType, Type, 
ObjectTypeNode);
 };
 
-/*!
- * \brief Primitive value.
- */
-class PrimTypeNode : public TypeNode {
- public:
-  /*! \brief Underlying primitive value, if known */
-  ffi::Optional<PrimExpr> value;
-
-  /*! \brief Underlying data type of the primitive value */
-  DataType dtype;
-
-  static void RegisterReflection() {
-    namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<PrimTypeNode>()
-        .def_ro("value", &PrimTypeNode::value)
-        .def_ro("dtype", &PrimTypeNode::dtype);
-  }
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PrimType", PrimTypeNode, TypeNode);
-};
-
-/*!
- * \brief Managed reference to PrimTypeNode.
- * \sa PrimTypeNode
- */
-class PrimType : public Type {
- public:
-  /* Construct a PrimType with a known dtype, but unknown value */
-  TVM_DLL PrimType(DataType dtype, Span span = Span());
-
-  /* Construct a PrimType with a known value */
-  TVM_DLL PrimType(PrimExpr value, Span span = Span());
-
-  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode);
-};
-
 /*!
  * \brief Type of shape value.
  */
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 3eea8b0b02..b0dca99248 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -54,7 +54,6 @@ from .expr import const, extern, get_shape_of
 from .type import (
     Type,
     ObjectType,
-    PrimType,
     ShapeType,
     TensorType,
     TupleType,
diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py
index d0b7ea3fc8..0152e965b9 100644
--- a/python/tvm/relax/backend/metal/coreml.py
+++ b/python/tvm/relax/backend/metal/coreml.py
@@ -24,6 +24,7 @@ import tvm_ffi
 
 import tvm
 from tvm.contrib import coreml_runtime
+from tvm.ir import PrimType
 from tvm.relax import transform
 from tvm.relax.dpl.pattern import is_op, wildcard
 from tvm.relax.expr import (
@@ -37,7 +38,7 @@ from tvm.relax.expr import (
     VarBinding,
 )
 from tvm.relax.transform import PatternCheckContext
-from tvm.relax.type import PrimType, TensorType
+from tvm.relax.type import TensorType
 from tvm.support.xcode import compile_coreml
 
 from ...expr_functor import PyExprVisitor, visitor
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index e02bbb51ca..ac10fb45ae 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -419,7 +419,7 @@ class _DLTensorShapeProxy(tvm.runtime.ObjectConvertible):
         if not isinstance(axis, tvm.relax.Expr):
             axis = tvm.relax.PrimValue(axis)
 
-        if axis.ty is not None and not isinstance(axis.ty, tvm.relax.PrimType):
+        if axis.ty is not None and not isinstance(axis.ty, tvm.ir.PrimType):
             raise TypeError(
                 f"The index used to access {self.tensor}.shape "
                 f'must have type R.Prim("int64"), '
@@ -487,7 +487,7 @@ class _DLTensorStrideProxy(tvm.runtime.ObjectConvertible):
         if not isinstance(axis, tvm.relax.Expr):
             axis = tvm.relax.PrimValue(axis)
 
-        if axis.ty is not None and not isinstance(axis.ty, tvm.relax.PrimType):
+        if axis.ty is not None and not isinstance(axis.ty, tvm.ir.PrimType):
             raise TypeError(
                 f"The index used to access {self.tensor}.strides "
                 f'must have type R.Prim("int64"), '
diff --git a/python/tvm/relax/script/parser/entry.py b/python/tvm/relax/script/parser/entry.py
index cbcc07ce1b..7301a73a1a 100644
--- a/python/tvm/relax/script/parser/entry.py
+++ b/python/tvm/relax/script/parser/entry.py
@@ -20,12 +20,12 @@ from collections.abc import Callable as _Callable
 from typing import Any, TypeVar
 
 import tvm
+from tvm.ir import PrimType
 from tvm.relax import (
     Expr,
     Function,
     FuncType,
     ObjectType,
-    PrimType,
     SeqExpr,
     ShapeExpr,
     ShapeType,
@@ -444,7 +444,6 @@ def Shape(values: list[PrimExpr] | None = None, ndim: int = -1) -> ShapeProxy:
 
 class PrimProxy(TypeProxy):
     dtype: str | None
-    value: int | float | str | PrimExpr | None
 
     """The type of TIR-representable values.
 
@@ -453,8 +452,6 @@ class PrimProxy(TypeProxy):
     dtype : Optional[str]
        The data type.
 
-    value: Optional[Union[int, float, str, PrimExpr]]
-       The known value
     """
 
     def __init__(
@@ -462,26 +459,23 @@ class PrimProxy(TypeProxy):
         dtype: str | None = None,
         value: int | float | str | PrimExpr | None = None,
     ) -> None:
-        if dtype is None and value is None:
-            raise TypeError(
-                "R.Prim missing required argument.  Must provide either 
'dtype' or 'value'"
-            )
+        if dtype is None:
+            if isinstance(value, PrimExpr):
+                dtype = value.dtype
+            elif isinstance(value, float):
+                dtype = "float32"
+            elif value is not None:
+                dtype = "int64"
+            else:
+                raise TypeError("R.Prim missing required argument 'dtype'")
 
         self.dtype = dtype
-        self.value = value
 
     def get_symbolic_vars(self) -> set[str]:
-        if isinstance(self.value, str) and self.value.isidentifier():
-            return {self.value}
-        else:
-            return set()
+        return set()
 
     def as_ty(self, dict_globals: dict[str, Any] | None = None) -> PrimType:
-        if self.value is None:
-            return PrimType(dtype=self.dtype)
-        else:
-            value = _eval_shape(self.value, dict_globals)
-            return PrimType(dtype=self.dtype, value=value)
+        return PrimType(self.dtype)
 
 
 def Prim(
@@ -515,7 +509,10 @@ def _normalize_ty_proxy(annotation) -> TypeProxy:
     if annotation is None:
         return TupleProxy([])
     elif callable(annotation):
-        return annotation()
+        annotation = annotation()
+        if isinstance(annotation, PrimExpr):
+            return PrimProxy(annotation.dtype)
+        return annotation
     elif isinstance(annotation, TypeProxy):
         return annotation
     else:
diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py
index eb20d8ab5d..220d113a23 100644
--- a/python/tvm/relax/testing/ast_printer.py
+++ b/python/tvm/relax/testing/ast_printer.py
@@ -281,7 +281,7 @@ class ASTPrinter(ExprFunctor):
             return self.build_ast_node("ShapeType", **fields)
         elif isinstance(ty_node, relax.ObjectType):
             return self.build_ast_node("ObjectType")
-        elif isinstance(ty_node, relax.PrimType):
+        elif isinstance(ty_node, tvm.ir.PrimType):
             return self.build_ast_node("PrimType", dtype=ty_node.dtype)
         elif isinstance(ty_node, relax.TensorType):
             fields = {}
diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py
index 432426bf74..e49ad1c948 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -216,7 +216,7 @@ class LazyTransformParamsFuncCreator:
             # direct iterate over the type annotation
             for param in func.params[num_input:]:
                 for ty in unpack_ty(param.ty):
-                    if isinstance(ty, relax.PrimType | relax.ShapeType):
+                    if isinstance(ty, tvm.ir.PrimType | relax.ShapeType):
                         params.append(relax.Var("symbolic_var_holder", ty))
 
         return relax.Function(
diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py
index b71d8958e0..5c7fdca1f4 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -18,11 +18,12 @@
 """Default legalization function for index operators."""
 
 from tvm import te, tirx, topi
+from tvm.ir import PrimType
 
 from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, PrimValue, Tuple
 from ...op import tensor_to_shape
-from ...type import PrimType, ShapeType
+from ...type import ShapeType
 from .common import register_legalize
 
 
@@ -36,11 +37,17 @@ def _take(bb: BlockBuilder, call: Call) -> Expr:
 @register_legalize("relax.strided_slice")
 def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
     def _relax_tuple_to_tir(relax_tuple):
+        if isinstance(relax_tuple, Tuple):
+            output = []
+            for field in relax_tuple.fields:
+                assert isinstance(field, PrimValue)
+                output.append(field.value)
+            return output
+
         output = []
         for field in relax_tuple.ty.fields:
             assert isinstance(field, PrimType)
-            assert field.value is not None
-            output.append(field.value)
+            return None
         return output
 
     if len(call.args) == 4:
diff --git a/python/tvm/relax/type.py b/python/tvm/relax/type.py
index 3ba5a86b90..ad8f469826 100644
--- a/python/tvm/relax/type.py
+++ b/python/tvm/relax/type.py
@@ -21,10 +21,7 @@
 import tvm_ffi
 from tvm_ffi import Array
 
-import tvm
-from tvm.ir import EnvFunc, Span, TupleType, VDevice
-from tvm.runtime import DataType
-from tvm.tirx import PrimExpr
+from tvm.ir import EnvFunc, PrimExpr, Span, TupleType, VDevice
 
 from . import _ffi_api
 from .expr import Expr, ShapeExpr, Type
@@ -39,72 +36,6 @@ class ObjectType(Type):
         self.__init_handle_by_constructor__(_ffi_api.ObjectType, span)  # 
type: ignore
 
 
-@tvm_ffi.register_object("relax.PrimType")
-class PrimType(Type):
-    """Type of a primitive POD value.
-
-    Parameters
-    ----------
-    dtype_or_expr : Union[str, DataType, PrimExpr]
-
-       The data type of the prim value, or a known expression for the prim
-       value.
-    """
-
-    value: PrimExpr | None
-    dtype: str
-
-    def __init__(
-        self,
-        dtype: str | DataType | None = None,
-        value: int | float | PrimExpr | None = None,
-        span: Span = None,
-    ) -> None:
-        # Guard against incorrect usage.  For backwards compatibility,
-        # the dtype and value are in the opposite order from most
-        # usages.  While PrimType could take a single positional
-        # argument and check the type, this would require an API
-        # difference from TVMScript's PrimProxy, which cannot.
-        # (PrimProxy uses string arguments for datatype, and also for
-        # inline variable definitions when used in a function
-        # signature, and requires separate arguments to distinguish
-        # the two cases.)
-        if isinstance(dtype, PrimExpr | int | float):
-            raise TypeError(
-                f"The first positional argument of PrimType must be the 
datatype, "
-                f", but received {type(dtype)}.  "
-                f"The value can be specified as a keyword argument "
-                f"without needing specifying the dtype: "
-                f"PrimType(value=arg)."
-            )
-
-        if dtype is None and value is None:
-            raise TypeError(
-                "PrimType.__init__ missing required argument.  "
-                "Must provide either 'dtype' or 'value'"
-            )
-
-        if dtype is not None:
-            if isinstance(value, PrimExpr):
-                assert value.dtype == dtype, (
-                    "When providing both 'value' and 'dtype' to 
PrimType.__init__, "
-                    "they must be consistent with each other.  "
-                    "However, the value {value} has dtype {value.dtype}, "
-                    "but the specified dtype was {dtype}."
-                )
-            elif isinstance(value, int | float):
-                value = tvm.tirx.const(value, dtype)
-
-        # Use relax's default integer type if not otherwise specified.
-        if isinstance(value, int):
-            value = tvm.tirx.IntImm("int64", value)
-
-        if value is None:
-            self.__init_handle_by_constructor__(_ffi_api.PrimTypeFromDtype, 
dtype, span)  # type: ignore
-        else:
-            self.__init_handle_by_constructor__(_ffi_api.PrimTypeFromValue, 
value, span)  # type: ignore
-
-
 @tvm_ffi.register_object("relax.ShapeType")
 class ShapeType(Type):
     """Type of a shape value.
@@ -261,5 +192,5 @@ class FuncType(Type):
         """
 
         if isinstance(derive_func, str):
-            derive_func = tvm.ir.EnvFunc.get("tvm.relax.type.infer_view_ty")
+            derive_func = EnvFunc.get("tvm.relax.type.infer_view_ty")
         return _ffi_api.FuncTypeOpaqueFunc(ret, derive_func, purity, span)  # 
type: ignore
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 89c9ac82c1..d143b43eaf 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -29,6 +29,7 @@ import tvm_ffi
 from tvm_ffi import Array, Map
 
 import tvm
+from tvm.ir import PrimType
 
 from .. import tirx
 from ..ir import Attrs, Type, VDevice
@@ -38,7 +39,7 @@ from ..tirx import PrimExpr
 from . import _ffi_api
 from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm, te_tensor
 from .expr import Tuple as rx_Tuple
-from .type import PrimType, ShapeType, TensorType
+from .type import ShapeType, TensorType
 
 
 def metadata_partitioner(rx_txt: str) -> list[str]:
@@ -250,23 +251,23 @@ def gen_call_tir_inputs(
                     return [_convert_te_arg_helper(val) for val in arg.values]
 
                 if isinstance(arg.ty, PrimType):
-                    if arg.ty.value is None:
-                        n_args = len(create_primfunc_args)
-                        if isinstance(arg, tvm.relax.Var):
-                            name = arg.name_hint
-                        elif n_args < len(string.ascii_lowercase):
-                            name = string.ascii_lowercase[n_args]
-                        else:
-                            name = f"scalar_input_{n_args}"
+                    if isinstance(arg, PrimValue):
+                        return _convert_te_arg_helper(arg.value)
 
-                        tir_param = tirx.Var(name, arg.ty.dtype)
+                    n_args = len(create_primfunc_args)
+                    if isinstance(arg, tvm.relax.Var):
+                        name = arg.name_hint
+                    elif n_args < len(string.ascii_lowercase):
+                        name = string.ascii_lowercase[n_args]
+                    else:
+                        name = f"scalar_input_{n_args}"
 
-                        call_tir_args.append(arg)
-                        create_primfunc_args.append(tir_param)
+                    tir_param = tirx.Var(name, arg.ty.dtype)
 
-                        return tir_param
-                    else:
-                        return _convert_te_arg_helper(arg.ty.value)
+                    call_tir_args.append(arg)
+                    create_primfunc_args.append(tir_param)
+
+                    return tir_param
 
             elif isinstance(arg, list | Array):
                 return [_convert_te_arg_helper(x) for x in arg]
diff --git a/src/relax/analysis/type_analysis.cc b/src/relax/analysis/type_analysis.cc
index b6c272a827..33070051ae 100644
--- a/src/relax/analysis/type_analysis.cc
+++ b/src/relax/analysis/type_analysis.cc
@@ -87,8 +87,6 @@ Type TypeFromStaticType(const Type& type) {
     return ObjectType(type->span);
   } else if (const PrimTypeNode* prim_type = type.as<PrimTypeNode>()) {
     return PrimType(prim_type->dtype, prim_type->span);
-  } else if (const tvm::PrimTypeNode* prim_type = 
type.as<tvm::PrimTypeNode>()) {
-    return PrimType(prim_type->dtype, prim_type->span);
   } else if (const ShapeTypeNode* shape_type = type.as<ShapeTypeNode>()) {
     return ShapeType(shape_type->ndim, type->span);
   } else if (const TensorTypeNode* tensor_type = type.as<TensorTypeNode>()) {
@@ -127,27 +125,7 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir
                     arith::AnalyzerObj* ana)
       : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {}
 
-  Type VisitType_(const PrimTypeNode* op) final {
-    bool has_undefined = false;
-    ffi::Optional<PrimExpr> value;
-
-    if (op->value.defined()) {
-      std::swap(has_undefined_, has_undefined);
-      value = VisitPrimExpr(op->value.value());
-      std::swap(has_undefined_, has_undefined);
-    }
-
-    // erase symbolic shape if we have undefined.
-    if (!has_undefined) {
-      if (value.same_as(op->value)) {
-        return ffi::GetRef<Type>(op);
-      } else {
-        return PrimType(value.value(), op->span);
-      }
-    } else {
-      return PrimType(op->dtype, op->span);
-    }
-  }
+  Type VisitType_(const PrimTypeNode* op) final { return 
ffi::GetRef<Type>(op); }
 
   Type VisitType_(const ShapeTypeNode* op) final {
     bool has_undefined = false;
@@ -341,10 +319,7 @@ class TypeBaseChecker : public TypeFunctor<BaseCheckResult(const Type&, const Ty
       return BaseCheckResult::kFailL0;
     }
 
-    if (!lhs->value.defined()) return BaseCheckResult::kPass;
-    if (!rhs->value.defined()) return BaseCheckResult::kFailL2;
-
-    return PrimValueMatchCheck(lhs->value.value(), rhs->value.value());
+    return BaseCheckResult::kPass;
   }
 
   BaseCheckResult VisitType_(const ShapeTypeNode* lhs, const Type& other) 
final {
@@ -662,13 +637,7 @@ class TypeBasePreconditionCollector : public TypeFunctor<PrimExpr(const Type&, c
       return IntImm::Bool(false);
     }
 
-    if (lhs->value.defined() && rhs->value.defined()) {
-      return lhs->value.value() == rhs->value.value();
-    } else if (lhs->value.defined() && !rhs->value.defined()) {
-      return IntImm::Bool(false);
-    } else {
-      return IntImm::Bool(true);
-    }
+    return IntImm::Bool(true);
   }
 
   PrimExpr VisitType_(const ShapeTypeNode* lhs, const Type& other) final {
@@ -1019,19 +988,6 @@ class TypeLCAFinder : public TypeFunctor<Type(const Type&, const Type&)> {
       // as a result we can unify to object.
       return ObjectType(lhs->span);
     }
-    if (!lhs->value.defined() || !rhs->value.defined() ||
-        !analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) {
-      // The two values are known to contain the same dtype, but may
-      // contain different values.
-      if (!lhs->value.defined()) {
-        // If the mismatch was due to extra information in the RHS,
-        // prefer to avoid constructing a new object.
-        return ffi::GetRef<Type>(lhs);
-      } else {
-        return PrimType(lhs->dtype, lhs->span);
-      }
-    }
-
     return ffi::GetRef<Type>(lhs);
   }
 
@@ -1234,11 +1190,7 @@ class TIRVarsDetector : public TypeVisitor {
     }
   }
 
-  void VisitType_(const PrimTypeNode* prim_ty) final {
-    if (prim_ty->value.defined()) {
-      VisitPrimExpr(prim_ty->value.value());
-    }
-  }
+  void VisitType_(const PrimTypeNode* prim_ty) final {}
 
   void VisitType_(const ShapeTypeNode* shape_ty) final {
     if (shape_ty->values.defined()) {
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 9ba34d6945..8cac4a12f7 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -648,13 +648,6 @@ class VMShapeLowerMutator
                 {value, DataTypeImm(op->dtype), GetErrContext(err_ctx)}, 
Attrs(), {void_ty_});
       builder_->Emit(call, "_");
     }
-    if (op->value.defined()) {
-      MatchShapeTodoItem item;
-      item.input = value;
-      item.pattern = {op->value.value()};
-      item.err_ctx = err_ctx;
-      match_todos->push_back(item);
-    }
   }
 
   void VisitType_(const ShapeTypeNode* op, Expr value, bool always_check, bool 
dynamic_only,
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index aab0fdf8b4..7a45902068 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -498,15 +498,6 @@ class BlockBuilderImpl : public BlockBuilderNode {
       }
     }
 
-    void VisitType_(const PrimTypeNode* op) final {
-      // Only collect single var defined shape. Ignore something like 
`R.Prim(value=m + 1)`
-      if (op->value.defined()) {
-        if (auto var = op->value.as<tirx::Var>()) {
-          shape_var_map_.Set(var.value(), op->value.value());
-        }
-      }
-    }
-
    private:
     ffi::Map<tirx::Var, PrimExpr> shape_var_map_;
   };
diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc
index cb59f566e6..d8100c2563 100644
--- a/src/relax/ir/dataflow_expr_rewriter.cc
+++ b/src/relax/ir/dataflow_expr_rewriter.cc
@@ -736,7 +736,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) {
       return ExternFuncPattern(func->global_symbol);
 
     } else if (auto prim = expr.as<PrimValueNode>()) {
-      return TypePattern(WildcardPattern(), PrimType(prim->value));
+      return TypePattern(WildcardPattern(), PrimType(prim->value.dtype()));
 
     } else {
       TVM_FFI_THROW(TypeError) << "Cannot convert Relax expression of type " 
<< expr->GetTypeKey()
diff --git a/src/relax/ir/dependent_type.cc b/src/relax/ir/dependent_type.cc
index c0ee21646d..6a2034ccc2 100644
--- a/src/relax/ir/dependent_type.cc
+++ b/src/relax/ir/dependent_type.cc
@@ -32,7 +32,6 @@ namespace relax {
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   ObjectTypeNode::RegisterReflection();
-  PrimTypeNode::RegisterReflection();
   ShapeTypeNode::RegisterReflection();
   TensorTypeNode::RegisterReflection();
   FuncTypeNode::RegisterReflection();
@@ -49,32 +48,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef().def("relax.ObjectType", [](Span span) { return 
ObjectType(span); });
 }
 
-// Prim
-PrimType::PrimType(PrimExpr value, Span span) {
-  ffi::ObjectPtr<PrimTypeNode> n = ffi::make_object<PrimTypeNode>();
-  n->dtype = value->dtype;
-  n->value = std::move(value);
-  n->span = span;
-  data_ = std::move(n);
-}
-
-PrimType::PrimType(DataType dtype, Span span) {
-  ffi::ObjectPtr<PrimTypeNode> n = ffi::make_object<PrimTypeNode>();
-  n->dtype = dtype;
-  n->value = std::nullopt;
-  n->span = span;
-  data_ = std::move(n);
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef()
-      .def("relax.PrimTypeFromDtype",
-           [](DataType dtype, Span span) { return PrimType(dtype, span); })
-      .def("relax.PrimTypeFromValue",
-           [](PrimExpr value, Span span) { return PrimType(value, span); });
-}
-
 // Shape
 ShapeType::ShapeType(ffi::Array<PrimExpr> values, Span span) {
   ffi::ObjectPtr<ShapeTypeNode> n = ffi::make_object<ShapeTypeNode>();
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index d84ddf2bae..ab9b0e92f0 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -363,7 +363,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 PrimValue::PrimValue(PrimExpr value, Span span) {
   ffi::ObjectPtr<PrimValueNode> n = ffi::make_object<PrimValueNode>();
-  n->ty = PrimType(value);
+  n->ty = PrimType(value.dtype());
   n->value = std::move(value);
   n->span = std::move(span);
   data_ = std::move(n);
diff --git a/src/relax/ir/type_functor.cc b/src/relax/ir/type_functor.cc
index d578db704a..e173e10c03 100644
--- a/src/relax/ir/type_functor.cc
+++ b/src/relax/ir/type_functor.cc
@@ -29,11 +29,7 @@ namespace relax {
 
 void TypeVisitor::VisitType_(const ObjectTypeNode* op) {}
 
-void TypeVisitor::VisitType_(const PrimTypeNode* op) {
-  if (op->value.defined()) {
-    this->VisitTypeExprField(op->value.value());
-  }
-}
+void TypeVisitor::VisitType_(const PrimTypeNode* op) {}
 
 void TypeVisitor::VisitType_(const ShapeTypeNode* op) {
   if (op->values.defined()) {
@@ -70,18 +66,7 @@ void TypeVisitor::VisitType_(const FuncTypeNode* op) {
 
 Type TypeMutator::VisitType_(const ObjectTypeNode* op) { return 
ffi::GetRef<Type>(op); }
 
-Type TypeMutator::VisitType_(const PrimTypeNode* op) {
-  if (!op->value.defined()) {
-    return ffi::GetRef<Type>(op);
-  }
-
-  auto new_expr = VisitTypeExprField(op->value.value());
-  if (new_expr.same_as(op->value)) {
-    return ffi::GetRef<Type>(op);
-  } else {
-    return PrimType(new_expr);
-  }
-}
+Type TypeMutator::VisitType_(const PrimTypeNode* op) { return 
ffi::GetRef<Type>(op); }
 
 Type TypeMutator::VisitType_(const ShapeTypeNode* op) {
   if (!op->values.defined()) {
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index 1b21432b8d..25ad9aa66d 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -135,10 +135,10 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) {
           << "Operator " << call->op
           << " expects the relative_byte_offset to be a 64-bit integer, but 
received "
           << arg_relative_byte_offset << ", which has type " << ty;
-      if (prim_ty->value.defined()) {
+      if (const auto* prim_value = 
arg_relative_byte_offset.as<PrimValueNode>()) {
         // An offset of known value is applied.  The known value may
         // be dynamic.
-        return prim_ty->value.value();
+        return prim_value->value;
       } else {
         // An offset of unknown value is applied.
         return std::nullopt;
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 2e1fa02591..739517de43 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -409,7 +409,7 @@ static ffi::Optional<Type> InferCallTIROutputTypeFromArguments(
       TVM_FFI_ICHECK(packed_tuple_ty);
       PrimType dummy_arg_ty = [&]() {
         if (packed_tuple_ty->values) {
-          return PrimType(packed_tuple_ty->values.value()[i]);
+          return PrimType(packed_tuple_ty->values.value()[i].dtype());
         } else {
           return PrimType(DataType::Int(64));
         }
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 0f82327994..665ea24e73 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -184,10 +184,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
  *
  * A `relax::Tuple` may be provided to an operator as an in-line
  * expression, as a variable bound to known tuple within the current
- * function, as a function argument, etc.  The Type of the tuple
- * tracks the known values of any `PrimValue` elements, but it can be
- * tedious to extract.  This utility extracts the `PrimExpr` contents
- * of a `relax::Tuple`.
+ * function, as a function argument, etc.  This overload validates that
+ * the Type could contain a tuple of `PrimValue` elements.  Without a
+ * concrete tuple expression, the values are not statically known.
  *
  * If the Type cannot contain a tuple of the type specified,
  * this function will throw an exception.  (e.g. Attempting to extract
@@ -198,7 +197,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
  *
  * \param ty The Type to inspect
  *
- * \returns An array of the `PrimType`, if it can be extracted.
+ * \returns An empty array for an empty tuple, if it can be extracted.
  *     Otherwise, `std::nullopt`.
  */
 template <typename PrimType = PrimExpr,
@@ -227,12 +226,7 @@ ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Type> t
         << "The type " << ty << " cannot contain a tuple whose elements are "
         << PrimType::ContainerType::_type_key << ", because element " << i << 
" has type " << field;
 
-    if (!prim_ty->value.defined()) return std::nullopt;
-
-    ffi::Optional<PrimType> element = prim_ty->value.as<PrimType>();
-    if (!element) return std::nullopt;
-
-    output.push_back(element.value());
+    return std::nullopt;
   }
   return output;
 }
@@ -241,10 +235,9 @@ ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Type> t
  *
  * A `relax::Tuple` may be provided to an operator as an in-line
  * expression, as a variable bound to known tuple within the current
- * function, as a function argument, etc.  The Type of the tuple
- * tracks the known values of any `PrimValue` elements, but it can be
- * tedious to extract.  This utility extracts the `PrimExpr` contents
- * of a `relax::Tuple`.
+ * function, as a function argument, etc.  This utility extracts
+ * `PrimValue` contents only when the concrete tuple expression is
+ * available.
  *
  * If the Type cannot contain a tuple of the type specified,
  * this function will throw an exception.  (e.g. Attempting to extract
@@ -261,11 +254,29 @@ ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Type> t
 template <typename PrimType = PrimExpr,
           typename = std::enable_if_t<std::is_base_of_v<PrimExpr, PrimType>>>
 ffi::Optional<ffi::Array<PrimType>> UnpackTupleOfPrimValue(ffi::Optional<Expr> expr) {
-  if (expr) {
-    return UnpackTupleOfPrimValue<PrimType>(GetType(expr.value()));
-  } else {
-    return std::nullopt;
+  if (!expr) return std::nullopt;
+
+  const Expr& value = expr.value();
+  if (const auto* tuple = value.as<TupleNode>()) {
+    ffi::Array<PrimType> output;
+    for (size_t i = 0; i < tuple->fields.size(); i++) {
+      const Expr& field = tuple->fields[i];
+      auto prim_value = field.as<PrimValueNode>();
+      TVM_FFI_CHECK(prim_value, TypeError)
+          << "The expression " << value << " cannot contain a tuple whose elements are "
+          << PrimType::ContainerType::_type_key << ", because element " << i << " is " << field;
+
+      TVM_FFI_CHECK(prim_value->value.template as<typename PrimType::ContainerType>(), TypeError)
+          << "The expression " << value << " cannot contain a tuple whose elements are "
+          << PrimType::ContainerType::_type_key << ", because element " << i << " has value "
+          << prim_value->value;
+
+      output.push_back(Downcast<PrimType>(prim_value->value));
+    }
+    return output;
   }
+
+  return UnpackTupleOfPrimValue<PrimType>(GetType(value));
 }
 
 Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) {
@@ -315,7 +326,7 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) {
     if (!tuple) return false;
 
     return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const Type& field) {
-      return IsBaseOf(relax::PrimType(DataType::Int(64)), field);
+      return IsBaseOf(tvm::PrimType(DataType::Int(64)), field);
     });
   };
   auto check_tuple = [&](const char* name, Expr expr) {
@@ -454,7 +465,7 @@ InferLayoutOutput InferLayoutStridedSlice(
     existing_layout = LayoutDecision(InitialLayout(tensor_ty->ndim));
   }
 
-  auto opt_axes_tuple = UnpackTupleOfPrimValue<IntImm>(GetType(call->args[1]));
+  auto opt_axes_tuple = UnpackTupleOfPrimValue<IntImm>(call->args[1]);
   TVM_FFI_ICHECK(opt_axes_tuple) << "Layout inference of " << call->op
                                  << " requires slices to be along static axes.  "
                                  << "However, expression " << call
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index ebfbccf11e..1494d407b9 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -51,7 +51,7 @@ TensorType GetTensorArgInfo(const Call& call) {
   return tensor_ty.value();
 }
 
-std::tuple<TensorType, PrimType> GetTensorArgInfoWithIndex(const Call& call) {
+std::tuple<TensorType, ffi::Optional<int64_t>> GetTensorArgInfoWithIndex(const Call& call) {
   TVM_FFI_CHECK_EQ(call->args.size(), 2, TypeError)
       << "Operator " << call->op << " expects two arguments, "
       << "but received " << call->args.size() << " arguments: " << call->args;
@@ -68,19 +68,24 @@ std::tuple<TensorType, PrimType> GetTensorArgInfoWithIndex(const Call& call) {
       << "Operator " << call->op << " expects arguments (tensor, axis), "
       << "but the second argument " << arg << " in expression " << call << " has type " << axis->ty;
 
-  auto int_imm_axis = axis_ty->value.as<IntImmNode>();
+  ffi::Optional<int64_t> int_imm_axis = std::nullopt;
+  if (const auto* prim_value = axis.as<PrimValueNode>()) {
+    if (const auto* int_imm = prim_value->value.as<IntImmNode>()) {
+      int_imm_axis = int_imm->value;
+    }
+  }
 
   if (int_imm_axis) {
-    TVM_FFI_ICHECK_GE(int_imm_axis->value, 0);
+    TVM_FFI_ICHECK_GE(int_imm_axis.value(), 0);
   }
   if (int_imm_axis && !tensor_ty->IsUnknownNdim()) {
-    TVM_FFI_CHECK_LT(int_imm_axis->value, tensor_ty->ndim, ValueError)
+    TVM_FFI_CHECK_LT(int_imm_axis.value(), tensor_ty->ndim, ValueError)
         << "Expression " << call << " attempts to access " << arg << ".shape["
-        << int_imm_axis->value << "]"
+        << int_imm_axis.value() << "]"
         << ", but " << arg << ".shape only has " << tensor_ty->ndim << " elements";
   }
 
-  return {ffi::GetRef<TensorType>(tensor_ty), ffi::GetRef<PrimType>(axis_ty)};
+  return {ffi::GetRef<TensorType>(tensor_ty), int_imm_axis};
 }
 
 DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; }
@@ -106,14 +111,7 @@ tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, DataTyp
   return func;
 }
 
-Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) {
-  if (auto prim_ty = call->ty.as<PrimTypeNode>()) {
-    if (prim_ty->value.defined()) {
-      return PrimValue(prim_ty->value.value());
-    }
-  }
-  return call;
-}
+Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) { return call; }
 
 //// relax.tensor_dtype_code
 
@@ -129,7 +127,7 @@ Type InferTypeTensorDtypeCode(const Call& call, const BlockBuilder&) {
   if (dtype.is_void()) {
     return PrimType(dlpack_type);
   } else {
-    return PrimType(IntImm(dlpack_type, dtype.code()));
+    return PrimType(dlpack_type);
   }
 }
 
@@ -167,7 +165,7 @@ Type InferTypeTensorDtypeBits(const Call& call, const BlockBuilder&) {
   if (dtype.is_void()) {
     return PrimType(dlpack_type);
   } else {
-    return PrimType(IntImm(dlpack_type, dtype.bits()));
+    return PrimType(dlpack_type);
   }
 }
 
@@ -205,7 +203,7 @@ Type InferTypeTensorDtypeLanes(const Call& call, const BlockBuilder&) {
   if (dtype.is_void()) {
     return PrimType(dlpack_type);
   } else {
-    return PrimType(IntImm(dlpack_type, dtype.lanes()));
+    return PrimType(dlpack_type);
   }
 }
 
@@ -243,7 +241,7 @@ Type InferTypeTensorNDim(const Call& call, const BlockBuilder&) {
   if (ty->IsUnknownNdim()) {
     return PrimType(dlpack_type);
   } else {
-    return PrimType(IntImm(dlpack_type, ty->ndim));
+    return PrimType(dlpack_type);
   }
 }
 
@@ -277,13 +275,12 @@ Expr tensor_shape_i(Expr expr) {
 Type InferTypeTensorShape(const Call& call, const BlockBuilder&) {
   auto dlpack_type = DataType::Int(64);
 
-  auto [tensor_ty, axis_ty] = GetTensorArgInfoWithIndex(call);
+  auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call);
 
   auto tensor_shape = tensor_ty->GetShape();
-  auto int_imm_axis = axis_ty->value.as<IntImmNode>();
 
   if (int_imm_axis && tensor_shape.defined()) {
-    return PrimType(tensor_shape.value()[int_imm_axis->value]);
+    return PrimType(tensor_shape.value()[int_imm_axis.value()].dtype());
   } else {
     return PrimType(dlpack_type);
   }
@@ -354,10 +351,9 @@ Expr tensor_stride_i(Expr expr) {
 Type InferTypeTensorStride(const Call& call, const BlockBuilder&) {
   auto dlpack_type = DataType::Int(64);
 
-  auto [tensor_ty, axis_ty] = GetTensorArgInfoWithIndex(call);
+  auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call);
 
   auto opt_tensor_shape = tensor_ty->GetShape();
-  auto int_imm_axis = axis_ty->value.as<IntImmNode>();
 
   if (int_imm_axis && opt_tensor_shape.defined()) {
     // As of 2024-03-14, Relax does not have an explicit
@@ -374,10 +370,10 @@ Type InferTypeTensorStride(const Call& call, const BlockBuilder&) {
     // for any legalizable Tensor.
     auto tensor_shape = opt_tensor_shape.value();
     PrimExpr stride = IntImm::Int64(1);
-    for (size_t axis = int_imm_axis->value + 1; axis < tensor_shape.size(); axis++) {
+    for (size_t axis = int_imm_axis.value() + 1; axis < tensor_shape.size(); axis++) {
       stride = stride * tensor_shape[axis];
     }
-    return PrimType(stride);
+    return PrimType(stride.dtype());
   } else {
     return PrimType(dlpack_type);
   }
@@ -409,7 +405,7 @@ Type InferTypeTensorByteOffset(const Call& call, const BlockBuilder&) {
     // Relax implicitly requires that the byte offset is zero for any
     // legalizable tensor.  See InferTypeTensorStride for full
     // explanation.
-    return PrimType(IntImm(dlpack_type, 0));
+    return PrimType(dlpack_type);
   } else {
     return PrimType(dlpack_type);
   }
@@ -440,7 +436,7 @@ Type InferTypeTensorElemOffset(const Call& call, const BlockBuilder&) {
     // Relax implicitly requires that the element offset is zero for
     // any legalizable tensor.  See InferTypeTensorStride for
     // full explanation.
-    return PrimType(IntImm(dlpack_type, 0));
+    return PrimType(dlpack_type);
   } else {
     return PrimType(dlpack_type);
   }
diff --git a/src/relax/script/printer/dependent_type.cc b/src/relax/script/printer/dependent_type.cc
index ee3aa6663c..a37c21406f 100644
--- a/src/relax/script/printer/dependent_type.cc
+++ b/src/relax/script/printer/dependent_type.cc
@@ -61,22 +61,6 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie
   return expr_doc;
 }
 
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch<relax::PrimType>("", [](relax::PrimType n, AccessPath n_p, IRDocsifier d) -> Doc {
-      ffi::Array<ExprDoc, void> args;
-      ffi::Array<ffi::String> kwargs_keys;
-      ffi::Array<ExprDoc, void> kwargs_values;
-
-      if (n->value.defined()) {
-        kwargs_keys.push_back("value");
-        kwargs_values.push_back(PrintShapeVar(n->value.value(), n_p->Attr("value"), d));
-      } else {
-        args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")));
-      }
-
-      return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values);
-    });
-
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::ShapeType>(
         "", [](relax::ShapeType n, AccessPath n_p, IRDocsifier d) -> Doc {
@@ -172,7 +156,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
         });
 
 TVM_REGISTER_SCRIPT_AS_REPR(relax::ObjectTypeNode, ReprPrintRelax);
-TVM_REGISTER_SCRIPT_AS_REPR(relax::PrimTypeNode, ReprPrintRelax);
 TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeTypeNode, ReprPrintRelax);
 TVM_REGISTER_SCRIPT_AS_REPR(relax::TensorTypeNode, ReprPrintRelax);
 TVM_REGISTER_SCRIPT_AS_REPR(relax::FuncTypeNode, ReprPrintRelax);
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 308c9da4c3..d54e8faf6e 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -980,8 +980,7 @@ class FusedTIRConstructor : public ExprVisitor {
 
     } else if (const auto* prim_value = ty.as<PrimTypeNode>()) {
       // Case 2. The relax param is a scalar, we directly create a tirx var
-      TVM_FFI_ICHECK(prim_value->value->IsInstance<tirx::VarNode>());
-      out->push_back(Downcast<tirx::Var>(prim_value->value));
+      out->push_back(tirx::Var(name_hint, prim_value->dtype));
 
     } else if (const auto* shape_expr = ty.as<ShapeTypeNode>()) {
       // Case 3. The relax param is a tuple of scalars, each represented as a tirx var
@@ -1255,13 +1254,14 @@ class TIRFuseMutator : public ExprMutator {
           tir_vars.push_back(prim_value);
         }
       } else if (const auto* prim_value = ty.as<PrimTypeNode>()) {
-        TVM_FFI_ICHECK(prim_value->value.defined())
-            << "FuseTIR requires all R.Prim arguments to have a known value.";
-        PrimExpr expr = prim_value->value.value();
-        TVM_FFI_ICHECK(expr->IsInstance<tirx::VarNode>())
-            << "FuseTIR currently requires all R.Prim "
-               "arguments to provide a single tirx::Var.";
-        tir_vars.push_back(expr);
+        if (const auto* literal = arg.as<PrimValueNode>()) {
+          tir_vars.push_back(literal->value);
+        } else if (const auto* var = arg.as<VarNode>()) {
+          tir_vars.push_back(tirx::Var(var->name_hint(), prim_value->dtype));
+        } else {
+          TVM_FFI_THROW(TypeError) << "FuseTIR expects scalar arguments to be PrimValue or Var, "
+                                   << "but received " << arg;
+        }
 
       } else {
         arg_list.push_back(arg);
diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc
index 598478c9c2..218a808565 100644
--- a/src/relax/transform/remove_unused_parameters.cc
+++ b/src/relax/transform/remove_unused_parameters.cc
@@ -100,7 +100,7 @@ std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
   }
 
   for (const auto& tir_var : free_tir_vars) {
-    Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var));
+    Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.dtype()));
     params.push_back(relax_var);
   }
 
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index d35b32ac58..370947e4b0 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -119,18 +119,6 @@ tvm::ffi::Map<tirx::Var, PrimExpr> InferSymbolicVarMap(
     }
   };
 
-  auto bind_from_prim_value = [&bind_from_prim_expr](const Type& var, const Type& expr) {
-    auto var_ty = var.as<PrimTypeNode>();
-    if (!var_ty) return;
-
-    auto expr_ty = expr.as<PrimTypeNode>();
-    if (!expr_ty) return;
-
-    if (!var_ty->value.defined() || !expr_ty->value.defined()) return;
-
-    bind_from_prim_expr(var_ty->value.value(), expr_ty->value.value());
-  };
-
   auto bind_from_shape = [&bind_from_prim_expr](const Type& var, const Type& expr) {
     auto var_shape = var.as<ShapeTypeNode>();
     if (!var_shape) return;
@@ -178,7 +166,6 @@ tvm::ffi::Map<tirx::Var, PrimExpr> InferSymbolicVarMap(
   bind_from_ty = [&](const Type& var, const Type& expr) {
     bind_from_tensor(var, expr);
     bind_from_shape(var, expr);
-    bind_from_prim_value(var, expr);
     bind_from_tuple(var, expr);
   };
 
diff --git a/src/tirx/ir/function.cc b/src/tirx/ir/function.cc
index c44c279980..d6b171481e 100644
--- a/src/tirx/ir/function.cc
+++ b/src/tirx/ir/function.cc
@@ -54,14 +54,14 @@ tvm::Type InferType(const PrimFunc& prim_func) {
         return relax::ObjectType();
       }
 
-      return relax::PrimType(param->dtype);
+      return PrimType(param->dtype);
     }();
     params.push_back(param_ty);
   }
 
   tvm::Type ret = [&]() -> tvm::Type {
     if (const auto* prim = prim_func->ret_type.as<PrimTypeNode>()) {
-      return relax::PrimType(prim->dtype);
+      return PrimType(prim->dtype);
     } else if (IsVoidType(prim_func->ret_type)) {
       return relax::TupleType(ffi::Array<tvm::Type>{});
     } else {
diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc
index 7b624304e5..07d9995bbd 100644
--- a/tests/cpp/nested_msg_test.cc
+++ b/tests/cpp/nested_msg_test.cc
@@ -145,9 +145,9 @@ TEST(NestedMsg, Equal) {
 }
 
 TEST(NestedMsg, MapAndDecompose) {
-  relax::Var x("x", relax::PrimType(runtime::DataType::Int(16)));
-  relax::Var y("y", relax::PrimType(runtime::DataType::Int(32)));
-  relax::Var z("z", relax::PrimType(runtime::DataType::Int(64)));
+  relax::Var x("x", PrimType(runtime::DataType::Int(16)));
+  relax::Var y("y", PrimType(runtime::DataType::Int(32)));
+  relax::Var z("z", PrimType(runtime::DataType::Int(64)));
 
   BlockBuilder bb = BlockBuilder::Create(std::nullopt);
   relax::Expr t0 = bb->Normalize(Tuple({x, y}));
@@ -169,7 +169,7 @@ TEST(NestedMsg, MapAndDecompose) {
                     [](IntImm lhs, IntImm rhs) -> bool { return lhs->value == rhs->value; }));
 
   auto output2 = MapToNestedMsg<IntImm>(GetType(t1), [&](Type ty) -> NestedMsg<IntImm> {
-    const auto* prim_ty = ty.as<relax::PrimTypeNode>();
+    const auto* prim_ty = ty.as<PrimTypeNode>();
     if (prim_ty == nullptr) return std::nullopt;
     int bits = prim_ty->dtype.bits();
     if (bits == 16) return c0;
@@ -306,7 +306,7 @@ TEST(NestedMsg, TransformTupleLeaf) {
   NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}};
   NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}};
 
-  relax::PrimType s = relax::PrimType(runtime::DataType::Int(32));
+  PrimType s = PrimType(runtime::DataType::Int(32));
   relax::Var x("x", s), y("y", s), z("z", s);
   BlockBuilder bb = BlockBuilder::Create(std::nullopt);
   Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})}));
diff --git a/tests/python/relax/test_analysis_type_analysis.py b/tests/python/relax/test_analysis_type_analysis.py
index 20ccb4a0e7..c91c504958 100644
--- a/tests/python/relax/test_analysis_type_analysis.py
+++ b/tests/python/relax/test_analysis_type_analysis.py
@@ -35,8 +35,8 @@ def test_get_static_type_basic():
     tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType())
 
     # prim
-    s1 = rx.PrimType("float32")
-    tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), rx.PrimType("float32"))
+    s1 = tvm.ir.PrimType("float32")
+    tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32"))
 
 
 def test_get_static_type_shape():
@@ -105,7 +105,7 @@ def test_erase_to_well_defined_basic():
     tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0)
 
     # prim
-    s1 = rx.PrimType("float32")
+    s1 = tvm.ir.PrimType("float32")
     tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1)
 
 
@@ -208,8 +208,8 @@ def test_base_check():
 
     n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64")
     obj0 = rx.ObjectType()
-    prim0 = rx.PrimType("int32")
-    prim1 = rx.PrimType("float32")
+    prim0 = tvm.ir.PrimType("int32")
+    prim1 = tvm.ir.PrimType("float32")
 
     shape0 = rx.ShapeType(ndim=-1)
     shape1 = rx.ShapeType(ndim=2)
@@ -362,7 +362,7 @@ def _check_derive(ctx, finfo, args_ty, ret):
 
 def test_derive_call_ret_type():
     obj0 = rx.ObjectType()
-    prim0 = rx.PrimType("float32")
+    prim0 = tvm.ir.PrimType("float32")
 
     n, m = tirx.Var("n0", "int64"), tirx.Var("m0", "int64")
     bb = rx.BlockBuilder()
@@ -517,8 +517,8 @@ def _check_lca(lhs, rhs, target):
 def test_type_lca():
     n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64")
     obj0 = rx.ObjectType()
-    prim0 = rx.PrimType("int32")
-    prim1 = rx.PrimType("float32")
+    prim0 = tvm.ir.PrimType("int32")
+    prim1 = tvm.ir.PrimType("float32")
 
     vdevice0 = ir.VDevice("llvm")
     vdevice1 = ir.VDevice("cuda", 0)
@@ -764,7 +764,7 @@ def test_collect_symbolic_var_from_tensor_shape():
     assert free_vars == {n, p, q}
 
 
-param_type = tvm.testing.parameter("shape_expr", "prim_value")
+param_type = tvm.testing.parameter("shape_expr")
 param_order = tvm.testing.parameter("definition_first", "usage_first")
 
 
@@ -779,11 +779,6 @@ def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
         extra_params = [
             rx.Var("shape_expr", rx.ShapeType([tir_n, tir_m])),
         ]
-    elif param_type == "prim_value":
-        extra_params = [
-            rx.Var("n", rx.PrimType(value=tir_n)),
-            rx.Var("m", rx.PrimType(value=tir_m)),
-        ]
     else:
         raise ValueError(f"Unknown param_type: {param_type}")
 
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index 49c52ef888..710daf55dc 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -289,7 +289,7 @@ def test_ty():
 
     assert printer.visit_ty_(rx.ObjectType()) == "ObjectType()"
 
-    assert printer.visit_ty_(rx.PrimType("int32")) == "PrimType(dtype=int32)"
+    assert printer.visit_ty_(tvm.ir.PrimType("int32")) == "PrimType(dtype=int32)"
 
     # empty shape
     empty_ssi = rx.ShapeType()
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 045468e57a..f2bc877694 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -16,6 +16,8 @@
 # under the License.
 # ruff: noqa: F841
 
+import pytest
+
 import tvm.script
 import tvm.testing
 from tvm import relax
@@ -816,6 +818,7 @@ def test_check_weights_with_dynamic_shape():
     assert_structural_equal(after, expected)
 
 
+@pytest.mark.skip(reason="value-bearing R.Prim annotations were removed")
 def test_update_symbolic_vars_in_match_cast_rhs():
     """Symbolic variables may be used on the RHS of match_cast"""
 
diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py
index 90fe4864c7..fdc514696f 100644
--- a/tests/python/relax/test_bind_symbolic_vars.py
+++ b/tests/python/relax/test_bind_symbolic_vars.py
@@ -204,6 +204,7 @@ def test_bind_symbolic_vars_in_shape_expr():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+@pytest.mark.skip(reason="value-bearing R.Prim annotations were removed")
 def test_bind_defining_of_symbolic_vars_in_prim_value():
     """R.Prim may define symbolic variables
 
diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py
index 2d2eb95ec8..34be64df5d 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -643,8 +643,8 @@ def test_emit_nested_tuple(emit_nested_tuple):
 
         n_sym = tirx.Var("n", "int64")
         m_sym = tirx.Var("m", "int64")
-        n = rx.Var("n", rx.PrimType(value=n_sym))
-        m = rx.Var("m", rx.PrimType(value=m_sym))
+        n = rx.Var("n", tvm.ir.PrimType("int64"))
+        m = rx.Var("m", tvm.ir.PrimType("int64"))
         x = rx.Var("x", rx.TensorType([n_sym, m_sym], "float32"))
         y = rx.Var("y", rx.TensorType([m_sym, n_sym], "float32"))
 
diff --git a/tests/python/relax/test_blockbuilder_emit_te.py b/tests/python/relax/test_blockbuilder_emit_te.py
index 0ca90e5a8b..7643d96980 100644
--- a/tests/python/relax/test_blockbuilder_emit_te.py
+++ b/tests/python/relax/test_blockbuilder_emit_te.py
@@ -75,7 +75,7 @@ def test_emit_te_with_symbolic_arg():
 
 
 def test_symbolic_shape_in_prim_value():
-    """Symbolic vars may be provided to TE in R.Prim"""
+    """Scalar Relax vars may be provided to TE as PrimFunc parameters."""
 
     def te_slice(tensor, i):
         return tvm.te.compute([tensor.shape[1]], lambda j: tensor[i, j], 
name="slice")
@@ -83,8 +83,7 @@ def test_symbolic_shape_in_prim_value():
     def from_builder():
         bb = rx.BlockBuilder()
         A = rx.Var("A", R.Tensor([16, 16], "float32"))
-        tir_i = tvm.tirx.Var("tir_i", "int64")
-        relax_i = rx.Var("relax_i", R.Prim(value=tir_i))
+        relax_i = rx.Var("relax_i", tvm.ir.PrimType("int64"))
 
         with bb.function("main", params=[A, relax_i]):
             A_sliced = bb.emit_te(te_slice, A, relax_i)
@@ -97,8 +96,8 @@ def test_symbolic_shape_in_prim_value():
         @T.prim_func(private=True, s_tir=True)
         def te_slice(
             A: T.Buffer([T.int64(16), T.int64(16)], "float32"),
-            Output: T.Buffer(T.int64(16), "float32"),
             row_index: T.int64,
+            Output: T.Buffer(T.int64(16), "float32"),
         ):
             T.func_attr({"tirx.noalias": True})
 
@@ -110,16 +109,13 @@ def test_symbolic_shape_in_prim_value():
         @R.function
         def main(
             A: R.Tensor([16, 16], "float32"),
-            arg_row_index: R.Prim(value="row_index"),
+            arg_row_index: R.Prim("int64"),
         ):
             cls = Expected
 
-            row_index = T.int64()
-
             gv = R.call_tir(
                 cls.te_slice,
-                A,
-                tir_vars=[row_index],
+                (A, arg_row_index),
                 out_ty=R.Tensor([16], "float32"),
             )
             return gv
diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py
index 5264913909..d9a46ba3fd 100644
--- a/tests/python/relax/test_dataflow_rewriter.py
+++ b/tests/python/relax/test_dataflow_rewriter.py
@@ -366,6 +366,7 @@ def test_recursive_rewrite_rules():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+@pytest.mark.skip(reason="value-bearing R.Prim match-cast semantics were removed")
 def test_rewrite_of_arbitrary_dtype():
     """A pattern-match may apply to a tensor with unknown dtype
 
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index 855361712e..b64804b4ac 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -271,7 +271,7 @@ def test_prim_value_with_var():
     n = tirx.Var("n", "int64")
     pv = rx.PrimValue(n)
     assert pv.value.same_as(n)
-    tvm.ir.assert_structural_equal(pv.ty, rx.PrimType(value=n))
+    tvm.ir.assert_structural_equal(pv.ty, tvm.ir.PrimType("int64"))
     _check_equal(pv, rx.PrimValue(n))
     _check_json_roundtrip(pv)
 
@@ -279,7 +279,7 @@ def test_prim_value_with_var():
 def test_prim_value_with_expr():
     n = tirx.Var("n", "int64")
     pv = rx.PrimValue(n + 1)
-    tvm.ir.assert_structural_equal(pv.ty, rx.PrimType(value=n + 1))
+    tvm.ir.assert_structural_equal(pv.ty, tvm.ir.PrimType("int64"))
     _check_equal(pv, rx.PrimValue(n + 1))
     _check_json_roundtrip(pv)
 
@@ -301,7 +301,7 @@ def test_datatype_imm():
 
 
 def test_call():
-    dtype = rx.PrimType("int32")
+    dtype = tvm.ir.PrimType("int32")
     func = rx.Var("func", rx.FuncType([dtype], dtype))
     arg = rx.Var("arg", dtype)
     call = rx.Call(func, [arg])
@@ -312,7 +312,7 @@ def test_call():
 
 def test_call_raises_error_for_invalid_function():
     """relax::Call requires the function to have FuncType"""
-    dtype = rx.PrimType("int32")
+    dtype = tvm.ir.PrimType("int32")
     func = rx.Var("func", dtype)
     arg = rx.Var("arg", dtype)
 
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index 953e744fb7..f5d12bbe67 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -141,7 +141,7 @@ def test_infer_ty_binary_arith_prim_value_with_prim_value(binary_arith_op: Calla
     x = relax.Var("x", R.Prim("float32"))
     y = relax.Var("y", R.Prim("float32"))
 
-    _check_inference(bb, binary_arith_op(x, y), relax.PrimType("float32"))
+    _check_inference(bb, binary_arith_op(x, y), tvm.ir.PrimType("float32"))
 
 
 @pytest.mark.parametrize("binary_arith_op,tir_arith_op", binary_arith_ops)
@@ -157,8 +157,8 @@ def test_infer_ty_binary_arith_known_prim_value_with_prim_value(
     x = relax.Var("x", R.Prim(value=tir_x))
     y = relax.Var("y", R.Prim(value=tir_y))
 
-    _check_inference(bb, binary_arith_op(x, y), relax.PrimType(value=tir_x + tir_y))
-    _check_inference(bb, binary_arith_op(y, x), relax.PrimType(value=tir_y + tir_x))
+    _check_inference(bb, binary_arith_op(x, y), tvm.ir.PrimType("float32"))
+    _check_inference(bb, binary_arith_op(y, x), tvm.ir.PrimType("float32"))
 
 
 binary_cmp_ops = [
@@ -202,8 +202,8 @@ def test_infer_ty_binary_cmp_prim_value_to_prim_value(binary_cmp_op: Callable):
     bb = relax.BlockBuilder()
     x = relax.Var("x", R.Prim("float32"))
     y = relax.Var("y", R.Prim("float32"))
-    _check_inference(bb, binary_cmp_op(x, y), relax.PrimType("bool"))
-    _check_inference(bb, binary_cmp_op(y, x), relax.PrimType("bool"))
+    _check_inference(bb, binary_cmp_op(x, y), tvm.ir.PrimType("bool"))
+    _check_inference(bb, binary_cmp_op(y, x), tvm.ir.PrimType("bool"))
 
 
 @pytest.mark.parametrize("binary_cmp_op,tir_cmp_op", binary_cmp_ops)
@@ -217,8 +217,8 @@ def test_infer_ty_binary_cmp_known_prim_value_to_prim_value(binary_cmp_op: Calla
     x = relax.Var("x", R.Prim(value=tir_x))
     y = relax.Var("y", R.Prim(value=tir_y))
 
-    _check_inference(bb, binary_cmp_op(x, y), relax.PrimType(value=tir_cmp_op(tir_x, tir_y)))
-    _check_inference(bb, binary_cmp_op(y, x), relax.PrimType(value=tir_cmp_op(tir_y, tir_x)))
+    _check_inference(bb, binary_cmp_op(x, y), tvm.ir.PrimType("bool"))
+    _check_inference(bb, binary_cmp_op(y, x), tvm.ir.PrimType("bool"))
 
 
 @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops])
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index c09a04893f..9a938b647d 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -289,7 +289,7 @@ def test_reshape_infer_ty_wrong_input_type():
     x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32")))
     x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
     ns = relax.Var("ns", relax.TensorType((120,), "float32"))
-    pv = relax.Var("pv", relax.PrimType("int64"))
+    pv = relax.Var("pv", tvm.ir.PrimType("int64"))
 
     with pytest.raises(TypeError):
         bb.normalize(relax.op.reshape(x0, (2, 3, 4, 5)))
@@ -2222,7 +2222,7 @@ def test_split_infer_ty_axis_out_of_range():
 def test_split_infer_invalid_ty_indices():
     bb = relax.BlockBuilder()
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
-    v = relax.Var("v", relax.PrimType("int64"))
+    v = relax.Var("v", tvm.ir.PrimType("int64"))
 
     with pytest.raises(TypeError):
         bb.normalize(relax.op.split(x0, [v], axis=1))
diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py
index 733dbc295a..a7b89f0654 100644
--- a/tests/python/relax/test_transform_compute_prim_value.py
+++ b/tests/python/relax/test_transform_compute_prim_value.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm.script import ir as I
@@ -82,6 +84,7 @@ def test_prim_value_in_branch_condition():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+@pytest.mark.skip(reason="value-bearing R.Prim annotations were removed")
 def test_prim_value_in_pure_function():
     @I.ir_module
     class Before:
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index 5642f72a09..35482e3bc0 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -16,6 +16,7 @@
 # under the License.
 # ruff: noqa: F841
 import numpy as np
+import pytest
 
 import tvm
 import tvm.testing
@@ -751,6 +752,7 @@ def test_params_without_tuple():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+@pytest.mark.skip(reason="value-bearing R.Prim annotations were removed")
 def test_retain_before_num_input():
     """Only lazily load parameters after num_input"""
 
@@ -844,6 +846,7 @@ def test_get_item_callback():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+@pytest.mark.skip(reason="value-bearing R.Prim annotations were removed")
 def test_get_item_callback_num_attrs():
     @I.ir_module(s_tir=True)
     class Before:
diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py b/tests/python/relax/test_transform_remove_unused_parameters.py
index 1be2a5ac2f..4c05cbdb29 100644
--- a/tests/python/relax/test_transform_remove_unused_parameters.py
+++ b/tests/python/relax/test_transform_remove_unused_parameters.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm.script import ir as I
@@ -54,6 +56,7 @@ def test_remove_unused_relax_parameter():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
[email protected](reason="value-bearing R.Prim annotations were removed")
 def test_replace_symbolic_variables():
     """If a parameter is only required for its symbolic variables, provide 
them directly
 
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index df774656e1..ef9ef115b1 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -705,7 +705,7 @@ def test_rewrite_dynamic_reshape():
     @I.ir_module(s_tir=True)
     class Before:
         @R.function
-        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+        def main(x: R.Tensor(["N", 16], dtype="float32")):
             N = T.int64()
             with R.dataflow():
                 y = R.reshape(x, [N * 4, T.int64(4)])
@@ -716,7 +716,7 @@ def test_rewrite_dynamic_reshape():
     @I.ir_module(s_tir=True)
     class Expected:
         @R.function
-        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+        def main(x: R.Tensor(["N", 16], dtype="float32")):
             N = T.int64()
             cls = Expected
 
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 0f251940dd..902c142b8e 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1337,7 +1337,7 @@ def test_computed_prim_value_as_branch_condition():
     if_else = func.body.blocks[0].bindings[0].value
     assert isinstance(if_else.cond, relax.PrimValue)
     tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value)
-    tvm.ir.assert_structural_equal(if_else.cond.ty, R.Prim(value=N % 16 == 0))
+    tvm.ir.assert_structural_equal(if_else.cond.ty, R.Prim("bool"))
 
 
 def test_tir_expr_as_branch_condition():
@@ -1409,7 +1409,7 @@ def test_computed_prim_value_as_assert_condition():
     condition = assert_op.args[0]
     assert isinstance(condition, relax.PrimValue)
     tvm.ir.assert_structural_equal(N % 16 == 0, condition.value)
-    tvm.ir.assert_structural_equal(condition.ty, R.Prim(value=N % 16 == 0))
+    tvm.ir.assert_structural_equal(condition.ty, R.Prim("bool"))
 
 
 def test_tir_expr_as_assert_condition():
@@ -1472,19 +1472,6 @@ def test_erase_to_well_defined_keeps_variants_exposed_by_shape_expr():
     _check(foo)
 
 
-def test_erase_to_well_defined_keeps_variants_exposed_by_prim_value():
-    @R.function
-    def foo(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")):
-        q = x
-        m, n = T.int64(), T.int64()
-        z = R.match_cast(q, R.Tensor((m, n)))
-        w = z
-        return w
-
-    assert foo.ret_ty.shape is not None
-    _check(foo)
-
-
 def test_erase_to_well_defined_infers_from_shape_expr():
     @I.ir_module(s_tir=True)
     class Module:
@@ -1510,33 +1497,6 @@ def test_erase_to_well_defined_infers_from_shape_expr():
     _check(Module)
 
 
-def test_erase_to_well_defined_infers_from_prim_value():
-    @I.ir_module(s_tir=True)
-    class Module:
-        # The subroutine's symbolic variables are only in-scope for the 
subroutine.
-        @R.function
-        def subroutine(x: R.Tensor, _m: R.Prim(value="m"), _n: 
R.Prim(value="n")) -> R.Tensor(
-            ["m", "n"]
-        ):
-            q = x
-            m, n = T.int64(), T.int64()
-            z = R.match_cast(q, R.Tensor((m, n)))
-            w = z
-            return w
-
-        # However, struct inference can make the symbolic variables in
-        # the main function to the symbolic variables in the
-        # subroutine.  Therefore, the shape of the tensor returned
-        # from main can have a well-defined shape.
-        @R.function
-        def main(x: R.Tensor, relax_m: R.Prim(value="m"), relax_n: 
R.Prim(value="n")):
-            output = Module.subroutine(x, relax_m, relax_n)
-            return output
-
-    assert Module["main"].ret_ty.shape is not None
-    _check(Module)
-
-
 def test_empty_tuple():
     @R.function
     def foo(x: R.Tuple()):
@@ -1617,26 +1577,6 @@ def test_symbolic_vars_in_shape():
     _check(baz, bb.get()["baz"])
 
 
-def test_symbolic_vars_in_prim_value():
-    """Symbolic variable may be defined in R.Prim"""
-
-    @R.function
-    def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")):
-        m = T.int64()
-        z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), 
dtype="float32"))
-        return z
-
-    m = tirx.Var("m", "int64")
-    x = relax.Var("x", relax.PrimType(value=m))
-    y = relax.Var("y", relax.TensorType([m * 2], "float32"))
-    bb = relax.BlockBuilder()
-    with bb.function("baz", (x, y)):
-        z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 
2,), dtype="float32")))
-        bb.emit_func_output(z)
-
-    _check(baz, bb.get()["baz"])
-
-
 def test_undefined_symbolic_var_raises_error():
     """An undefined symbolic variable in an error
 
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 012aac8c55..9b4fa20678 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -177,8 +177,8 @@ def test_object_ty():
 
 
 def test_prim_ty():
-    obj = relax.PrimType("float32")
-    _assert_print(obj, 'R.Prim("float32")')
+    obj = tvm.ir.PrimType("float32")
+    _assert_print(obj, "T.float32")
 
 
 def test_shape_ty_0():
@@ -223,7 +223,7 @@ def test_tuple_ty_empty():
 def test_tuple_ty():
     obj = relax.TupleType(
         [
-            relax.PrimType("float32"),
+            tvm.ir.PrimType("float32"),
             relax.ObjectType(),
             relax.ShapeType([1, tirx.Var("a", "int64"), 3]),
         ]
@@ -231,7 +231,7 @@ def test_tuple_ty():
     _assert_print(
         obj._relax_script(),  # pylint: disable=protected-access
         """
-R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3]))
+R.Tuple(T.float32, R.Object, R.Shape([1, a, 3]))
 """,
     )
 
@@ -239,10 +239,10 @@ R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3]))
 def test_func_ty():
     obj = relax.FuncType(
         params=[
-            relax.PrimType("float32"),
+            tvm.ir.PrimType("float32"),
             relax.ObjectType(),
             relax.ShapeType([1, tirx.Var("a", "int64"), 3]),
-            relax.PrimType(value=tirx.Var("b", "int64")),
+            tvm.ir.PrimType("int64"),
         ],
         ret=relax.TensorType(
             shape=relax.ShapeExpr([1, 2, 3]),
@@ -252,8 +252,7 @@ def test_func_ty():
     _assert_print(
         obj,
         "a = T.int64()\n"
-        "b = T.int64()\n"
-        'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3]), 
R.Prim(value=b)), '
+        "R.Callable((T.float32, R.Object, R.Shape([1, a, 3]), T.int64), "
         'R.Tensor((1, 2, 3), dtype="float32"), True)',
     )
 
diff --git a/tests/python/relax/test_type.py b/tests/python/relax/test_type.py
index 1679048df4..8490b1fbbd 100644
--- a/tests/python/relax/test_type.py
+++ b/tests/python/relax/test_type.py
@@ -67,9 +67,9 @@ def test_dyn_tensor_type():
 
 
 def test_prim_ty():
-    s0 = rx.PrimType("float32")
-    s1 = rx.PrimType("float32")
-    s2 = rx.PrimType("int32")
+    s0 = tvm.ir.PrimType("float32")
+    s1 = tvm.ir.PrimType("float32")
+    s2 = tvm.ir.PrimType("int32")
 
     _check_equal(s0, s1)
 
@@ -79,7 +79,7 @@ def test_prim_ty():
     assert s0 == s1
     assert s0 != s2
 
-    assert isinstance(s0, rx.PrimType)
+    assert isinstance(s0, tvm.ir.PrimType)
     _check_json_roundtrip(s0)
     _check_json_roundtrip(s1)
 
@@ -88,23 +88,7 @@ def test_prim_ty():
 
     # wrong API constructors
     with pytest.raises((RuntimeError, TypeError)):
-        rx.PrimType([1])
-
-
-def test_prim_ty_with_expr():
-    n = tirx.Var("n", "int64")
-    ty = rx.PrimType(value=n + 1)
-
-    _check_equal(ty, rx.PrimType(value=n + 1))
-    assert not tvm_ffi.structural_equal(ty, rx.PrimType(dtype=n.dtype))
-
-    # can turn into str
-    str(ty)
-
-    assert isinstance(ty, rx.PrimType)
-    _check_json_roundtrip(ty)
-
-    assert ty.dtype == "int64"
+        tvm.ir.PrimType([1])
 
 
 def test_shape_ty():
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index c2afcaf21b..1a15484ded 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -171,6 +171,7 @@ def test_structural_equal_of_call_nodes():
     tvm.ir.assert_structural_equal(uses_same_object_twice, 
uses_two_different_objects)
 
 
[email protected](reason="value-bearing R.Prim annotations were removed")
 def test_structural_equal_with_recursive_lambda_function():
     """A recursive lambda function may be checked for structural equality
 
@@ -263,10 +264,9 @@ def test_structural_equal_with_distinct_recursive_lambda_function():
         "blocks[0]",
         "bindings[0]",
         "value",
-        "true_branch",
-        "body",
-        "value",
+        "cond",
         "value",
+        "a",
     ]
 
     with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))):
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index d555f0d5a9..d04f59379f 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -556,6 +556,7 @@ def test_vm_relax_symbolic_shape_tuple(exec_mode):
         func(R.prim_value(2))
 
 
[email protected](reason="value-bearing R.Prim annotations are erased to 
dtype-only PrimType")
 def test_vm_relax_symbolic_prim_value(exec_mode):
     @I.ir_module(s_tir=True)
     class mod:
@@ -576,6 +577,7 @@ def test_vm_relax_symbolic_prim_value(exec_mode):
         func(Shape([2]))
 
 
[email protected](reason="value-bearing R.Prim annotations are erased to 
dtype-only PrimType")
 def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
     """Like test_vm_relax_symbolic_prim_value, but with multiple variables"""
 
diff --git a/tests/python/tirx-base/test_tir_specialize.py b/tests/python/tirx-base/test_tir_specialize.py
index cecaf07ab8..0529bd90a4 100644
--- a/tests/python/tirx-base/test_tir_specialize.py
+++ b/tests/python/tirx-base/test_tir_specialize.py
@@ -347,10 +347,10 @@ def test_specialization_updates_ty():
     def expected() -> T.int32:
         T.ret(50)
 
-    ty_before = tvm.relax.FuncType([tvm.relax.PrimType("int32")], 
tvm.relax.PrimType("int32"))
+    ty_before = tvm.relax.FuncType([tvm.ir.PrimType("int32")], 
tvm.ir.PrimType("int32"))
     tvm.ir.assert_structural_equal(before.ty, ty_before)
 
-    ty_expected = tvm.relax.FuncType([], tvm.relax.PrimType("int32"))
+    ty_expected = tvm.relax.FuncType([], tvm.ir.PrimType("int32"))
     tvm.ir.assert_structural_equal(expected.ty, ty_expected)
 
     n = before.params[0]
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 5972decaa5..9c1e26459d 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -413,10 +413,10 @@ def test_inferred_ty_with_prim_args():
 
     expected = tvm.relax.FuncType(
         [
-            tvm.relax.PrimType("int32"),
-            tvm.relax.PrimType("int32"),
+            tvm.ir.PrimType("int32"),
+            tvm.ir.PrimType("int32"),
         ],
-        tvm.relax.PrimType("int32"),
+        tvm.ir.PrimType("int32"),
         purity=True,
     )
     tvm.ir.assert_structural_equal(func.ty, expected)
@@ -434,7 +434,7 @@ def test_inferred_ty_with_buffer_args():
             tvm.relax.TensorType([16, 16], "float32"),
             tvm.relax.TensorType([256], "int32"),
         ],
-        tvm.relax.PrimType("float32"),
+        tvm.ir.PrimType("float32"),
         purity=True,
     )
     tvm.ir.assert_structural_equal(func.ty, expected)
@@ -460,7 +460,7 @@ def test_inferred_ty_with_internal_allocation():
         [
             tvm.relax.TensorType([16, 16], "float32"),
         ],
-        tvm.relax.PrimType("float32"),
+        tvm.ir.PrimType("float32"),
         purity=True,
     )
     tvm.ir.assert_structural_equal(func.ty, expected)

Reply via email to