This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 05e2bc3340 [Relax] Implement R.ensure_zero_offset and update memory 
planning for R.view (#17145)
05e2bc3340 is described below

commit 05e2bc3340d1c0ca505e8a66bee29ffd5d294379
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Aug 6 07:13:49 2024 -0700

    [Relax] Implement R.ensure_zero_offset and update memory planning for 
R.view (#17145)
    
    Previously, `R.view` was legalized to an extern call to
    `runtime.TVMArrayCreateView` during `LegalizeOps`. This extern call
    can't be properly handled by `StaticBlockPlanMemory`, because the pass
    assumes an extern func does not retain its input buffer. An extern func
    returning a view of the input would break the ref count of the
    buffer. This PR defers the legalization of `R.view` so that it can be
    explicitly handled by memory planning.
    
    A new op `R.ensure_zero_offset` is added as discussed in #16955
---
 include/tvm/relax/backend.h                        |  2 +-
 include/tvm/relax/op_attr_types.h                  |  9 ++++
 include/tvm/runtime/device_api.h                   |  5 ++
 python/tvm/relax/op/memory/__init__.py             |  2 +-
 python/tvm/relax/op/memory/view.py                 | 17 +++++++
 python/tvm/relax/pipeline.py                       |  2 +-
 python/tvm/relax/transform/__init__.py             |  9 ++--
 python/tvm/relax/transform/transform.py            | 17 ++++++-
 ...m_builtin_lower.cc => lower_runtime_builtin.cc} | 26 ++++++----
 src/relax/op/memory/view.cc                        | 35 +++++++++++--
 src/relax/op/memory/view.h                         |  3 ++
 src/relax/transform/static_plan_block_memory.cc    | 13 +++--
 src/runtime/cpu_device_api.cc                      |  2 +
 src/runtime/cuda/cuda_device_api.cc                |  2 +
 src/runtime/relax_vm/builtin.cc                    | 19 ++++++++
 tests/python/relax/test_op_view.py                 | 31 ++++++------
 .../test_transform_static_plan_block_memory.py     | 57 +++++++++++++++++++++-
 tests/python/relax/test_vm_builtin_lower.py        |  4 +-
 18 files changed, 211 insertions(+), 44 deletions(-)

diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
index 2fb11f5a6f..e7d13c47b2 100644
--- a/include/tvm/relax/backend.h
+++ b/include/tvm/relax/backend.h
@@ -35,7 +35,7 @@ namespace transform {
  *
  * \return The Pass.
  */
-TVM_DLL Pass VMBuiltinLower();
+TVM_DLL Pass LowerRuntimeBuiltin();
 
 /*!
  * \brief Lower the shape expression in relax to VM shape heap and TIR 
functions.
diff --git a/include/tvm/relax/op_attr_types.h 
b/include/tvm/relax/op_attr_types.h
index b44c4582d8..291bee597c 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -79,6 +79,15 @@ using FNormalize = runtime::TypedPackedFunc<Expr(const 
BlockBuilder& bb, Call ca
  */
 using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const 
Call& call)>;
 
+/*! \brief The function type of a function to lower the runtime builtin.
+ *
+ * A builtin function may be rewritten into its lowered form by the 
`LowerRuntimeBuiltin` pass.
+ *
+ * \param bb The BlockBuilder context.
+ * \param call The call to be lowered.
+ */
+using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, 
const Call& call)>;
+
 /*!
  * \brief Gradient for a specific op.
  *
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 14b2b84b0d..c33606d98e 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -240,6 +240,11 @@ class TVM_DLL DeviceAPI {
     return device_type != kDLCPU && device_type != kDLMicroDev;
   }
 
+  /*!
+   * \brief Whether pointer arithmetics on a device owned pointer may be 
performed on the host.
+   */
+  virtual bool SupportsDevicePointerArithmeticsOnHost() { return false; }
+
  protected:
   /*!
    * \brief copy data from one place to another
diff --git a/python/tvm/relax/op/memory/__init__.py 
b/python/tvm/relax/op/memory/__init__.py
index 422c5d2e1f..1191550085 100644
--- a/python/tvm/relax/op/memory/__init__.py
+++ b/python/tvm/relax/op/memory/__init__.py
@@ -17,4 +17,4 @@
 """Relax memory primitives."""
 
 from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
-from .view import view
+from .view import view, ensure_zero_offset
diff --git a/python/tvm/relax/op/memory/view.py 
b/python/tvm/relax/op/memory/view.py
index 0c3d8a03b2..95adc78209 100644
--- a/python/tvm/relax/op/memory/view.py
+++ b/python/tvm/relax/op/memory/view.py
@@ -92,3 +92,20 @@ def view(
     relative_byte_offset = _normalize(relative_byte_offset, PrimValue)
 
     return _ffi_api.view(data, shape, dtype, relative_byte_offset)  # type: 
ignore
+
+
+def ensure_zero_offset(data: Expr) -> Expr:
+    """
+    Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor
+
+    Returns
+    -------
+    result : relax.Expr
+        The tensor with elem_offset == 0
+    """
+    return _ffi_api.ensure_zero_offset(data)  # type: ignore
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index d068f800d0..38242ff4d2 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -92,7 +92,7 @@ def default_build_pipeline():
                 transform.RewriteCUDAGraph(),
                 transform.LowerAllocTensor(),
                 transform.KillAfterLastUse(),
-                transform.VMBuiltinLower(),
+                transform.LowerRuntimeBuiltin(),
                 transform.ComputePrimValue(),
                 transform.VMShapeLower(),
                 transform.AttachGlobalSymbol(),
diff --git a/python/tvm/relax/transform/__init__.py 
b/python/tvm/relax/transform/__init__.py
index 5789e2fcf2..1ce864651c 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -55,6 +55,7 @@ from .transform import (
     LegalizeOps,
     LiftTransformParams,
     LowerAllocTensor,
+    LowerRuntimeBuiltin,
     MergeCompositeFunctions,
     MetaScheduleApplyDatabase,
     MetaScheduleTuneIRMod,
@@ -64,8 +65,8 @@ from .transform import (
     PatternCheckContext,
     RealizeVDevice,
     RemovePurityChecking,
-    RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    RemoveUnusedParameters,
     ReorderPermuteDimsAfterConcat,
     ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
@@ -84,14 +85,14 @@ from .transform import (
     function_pass,
 )
 
+from .attach_external_modules import AttachExternModules
+from .fast_math import FastMathTransform
+from .fuse_transpose_matmul import FuseTransposeMatmul
 from .ipc_allreduce_rewrite import IPCAllReduceRewrite
 from .lazy_transform_params import LazyTransformParams
 from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
 from .optimize_layout_transform import OptimizeLayoutTransform
 from .remove_redundant_reshape import RemoveRedundantReshape
-from .fast_math import FastMathTransform
-from .fuse_transpose_matmul import FuseTransposeMatmul
-from .attach_external_modules import AttachExternModules
 
 # Import to register the legalization functions.
 from . import legalize_ops, tuning_api
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 3528b4429e..2546284625 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,6 +19,7 @@
 import functools
 import inspect
 import types
+import warnings
 from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, 
Union
 
 import numpy as np  # type: ignore
@@ -586,6 +587,16 @@ def ComputePrimValue() -> tvm.ir.transform.Pass:
     return _ffi_api.ComputePrimValue()  # type: ignore
 
 
+def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass:
+    """Lowering generic intrinsic to VM intrinsics.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.LowerRuntimeBuiltin()  # type: ignore
+
+
 def VMBuiltinLower() -> tvm.ir.transform.Pass:
     """Lowering generic intrinsic to VM intrinsics.
 
@@ -593,7 +604,11 @@ def VMBuiltinLower() -> tvm.ir.transform.Pass:
     -------
     ret: tvm.ir.transform.Pass
     """
-    return _ffi_api.VMBuiltinLower()  # type: ignore
+    warnings.warn(
+        "tvm.relax.transform.VMBuiltinLower has been renamed to 
'LowerRuntimeBuiltin'.  "
+        "This wrapper is for backwards compatibility, and will be removed in a 
later update."
+    )
+    return _ffi_api.LowerRuntimeBuiltin()  # type: ignore
 
 
 def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc 
b/src/relax/backend/vm/lower_runtime_builtin.cc
similarity index 90%
rename from src/relax/backend/vm/vm_builtin_lower.cc
rename to src/relax/backend/vm/lower_runtime_builtin.cc
index 887998d004..a3867ae924 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -17,13 +17,14 @@
  * under the License.
  */
 /*!
- * \file src/relax/backend/vm/vm_builtin_lower.cc
+ * \file src/relax/backend/vm/lower_runtime_builtin.cc
  * \brief Lowers most builtin functions and packed calls.
  */
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/attrs/op.h>
 #include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/type.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/tir/op.h>
@@ -33,11 +34,12 @@ namespace relax {
 
 // This pass lowers most ops to VM specific builtins.
 // TODO(relax-team): revisit after PrimValue.
-class VMBuiltinLowerMutator : public ExprMutator {
+class LowerRuntimeBuiltinMutator : public ExprMutator {
  public:
   using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call_node) final {
+    static const auto& lower_builtin_fmap = 
Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
     // post-order mutation
     Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
 
@@ -64,9 +66,13 @@ class VMBuiltinLowerMutator : public ExprMutator {
       return MakeMemAllocTensor(call);
     } else if (call->op == mem_kill_storage_op_ || call->op == 
mem_kill_tensor_op_) {
       return MakeMemKillObject(call);
-    } else {
-      return call;
+    } else if (const auto* op_node = call->op.as<OpNode>()) {
+      Op op = GetRef<Op>(op_node);
+      if (lower_builtin_fmap.count(op)) {
+        return lower_builtin_fmap[op](builder_, call);
+      }
     }
+    return call;
   }
 
   Expr MakeMemAllocStorage(const Call& call) {
@@ -210,17 +216,19 @@ class VMBuiltinLowerMutator : public ExprMutator {
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
 };
 
-Expr VMBuiltinLower(const Expr& e) { return 
VMBuiltinLowerMutator().VisitExpr(e); }
+Expr LowerRuntimeBuiltin(const Expr& e) { return 
LowerRuntimeBuiltinMutator().VisitExpr(e); }
 
 namespace transform {
 
-Pass VMBuiltinLower() {
+Pass LowerRuntimeBuiltin() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
-      [=](Function f, IRModule m, PassContext pc) { return 
Downcast<Function>(VMBuiltinLower(f)); };
-  return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(LowerRuntimeBuiltin(f));
+      };
+  return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
 }
 
-TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
+TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin);
 
 }  // namespace transform
 }  // namespace relax
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index e7634c7edf..21a72f6200 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -291,7 +291,7 @@ StructInfo InferStructInfoView(const Call& call, const 
BlockBuilder& ctx) {
 
 
TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);
 
-Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
+Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
   Expr data = call->args[0];
   Expr shape = call->args[1];
   Expr dtype = call->args[2];
@@ -352,8 +352,37 @@ TVM_REGISTER_OP("relax.memory.view")
                   "The view's byte offset, relative to the input tensor's byte 
offset.")
     .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
-    .set_attr<FLegalize>("FLegalize", LegalizeView)
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);
+
+Expr ensure_zero_offset(const Expr& x) {
+  static const Op& op = Op::Get("relax.memory.ensure_zero_offset");
+  return Call(op, {x});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset);
+
+StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const 
BlockBuilder& ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Operator " << call->op << " should receive 1 
argument, "
+                     << "but received " << call->args);
+  }
+  return GetStructInfo(call->args[0]);
+}
+
+Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) {
+  const ExternFunc 
builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"};
+  return Call(builtin_ensure_zero_offset_, call->args, Attrs(), 
{GetStructInfo(call)});
+}
+
+TVM_REGISTER_OP("relax.memory.ensure_zero_offset")
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoEnsureZeroOffset)
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinEnsureZeroOffset);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h
index bc8002fa5b..77ec7e9833 100644
--- a/src/relax/op/memory/view.h
+++ b/src/relax/op/memory/view.h
@@ -32,6 +32,9 @@ namespace relax {
 /*! \brief View a tensor with different properties. */
 Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> 
relative_byte_offset);
 
+/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if 
necessary. */
+Expr ensure_zero_offset(const Expr& x);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/relax/transform/static_plan_block_memory.cc 
b/src/relax/transform/static_plan_block_memory.cc
index 2b16d86509..74200526b6 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -286,8 +286,13 @@ class TokenAllocator1D {
   std::vector<StorageToken> full_pool_;
 };
 
-/*! \brief Check if the input op is "relax.reshape". */
-bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
+/*! \brief Check if the input op is a memory op that may return the same 
buffer. */
+bool IsInplaceMemoryOp(const Expr& op) {
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& view_op = Op::Get("relax.memory.view");
+  static const Op& ensure_zero_offset_op = 
Op::Get("relax.memory.ensure_zero_offset");
+  return op.same_as(reshape_op) || op.same_as(view_op) || 
op.same_as(ensure_zero_offset_op);
+}
 
 /*! \brief The base class for the storage allocation visitor. */
 class StorageAllocatorBaseVisitor : public ExprVisitor {
@@ -498,7 +503,7 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
       // Create a storage token for builtin alloc_tensor.
       this->CreateToken(call);
       return;
-    } else if (IsReshape(call->op)) {
+    } else if (IsInplaceMemoryOp(call->op)) {
       // Reuse the input's token for builtin reshape.
       SetTokens(call, GetTokens(call->args[0]));
       return;
@@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor 
{
         block_tokens.push_back(new_token.get());
       }
       return;
-    } else if (IsReshape(call->op)) {
+    } else if (IsInplaceMemoryOp(call->op)) {
       Tokens tokens = GetTokens(call->args[0]);
       ICHECK(!tokens.IsNested());
       if (tokens.IsLeaf()) {
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index 774335f566..ccd726a6ec 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -73,6 +73,8 @@ class CPUDeviceAPI final : public DeviceAPI {
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
   void FreeWorkspace(Device dev, void* data) final;
 
+  bool SupportsDevicePointerArithmeticsOnHost() final { return true; }
+
   static CPUDeviceAPI* Global() {
     // NOTE: explicitly use new to avoid exit-time destruction of global state
     // Global state will be recycled by OS as the process exits.
diff --git a/src/runtime/cuda/cuda_device_api.cc 
b/src/runtime/cuda/cuda_device_api.cc
index 66357a1915..33908d750d 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -262,6 +262,8 @@ class CUDADeviceAPI final : public DeviceAPI {
     CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data);
   }
 
+  bool SupportsDevicePointerArithmeticsOnHost() final { return true; }
+
   static CUDADeviceAPI* Global() {
     // NOTE: explicitly use new to avoid exit-time destruction of global state
     // Global state will be recycled by OS as the process exits.
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index af1cf9d203..9fe6fba80f 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -551,6 +551,25 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
   return ShapeTuple(out_shape);
 });
 
+TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray 
data) {
+  if (data->byte_offset == 0) {
+    return data;
+  }
+  auto* device_api = DeviceAPI::Get(data->device);
+  if (device_api->SupportsDevicePointerArithmeticsOnHost() &&
+      data->byte_offset % tvm::runtime::kAllocAlignment == 0) {
+    DLManagedTensor* dl_tensor = data.ToDLPack();
+    dl_tensor->dl_tensor.data =
+        reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + 
dl_tensor->dl_tensor.byte_offset;
+    dl_tensor->dl_tensor.byte_offset = 0;
+    return NDArray::FromDLPack(dl_tensor);
+  } else {
+    auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device);
+    new_array.CopyFrom(data);
+    return new_array;
+  }
+});
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/relax/test_op_view.py 
b/tests/python/relax/test_op_view.py
index 2433821c2a..0900e1be30 100644
--- a/tests/python/relax/test_op_view.py
+++ b/tests/python/relax/test_op_view.py
@@ -452,7 +452,9 @@ def test_applying_unknown_relative_byte_offset_is_legal():
     tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)
 
 
-def test_legalize_without_any_changes_is_no_op():
+def test_legalize_is_no_op():
+    """R.memory.view is not legalized until LowerRuntimeBuiltin"""
+
     @I.ir_module
     class Before:
         @R.function
@@ -460,18 +462,13 @@ def test_legalize_without_any_changes_is_no_op():
             B = R.memory.view(A)
             return B
 
-    @I.ir_module
-    class Expected:
-        @R.function
-        def main(A: R.Tensor([4096], "float32")):
-            B = A
-            return B
+    Expected = Before
 
     After = tvm.relax.transform.LegalizeOps()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_shape_change():
+def test_lower_runtime_builtin_shape_change():
     @I.ir_module
     class Before:
         @R.function
@@ -497,11 +494,11 @@ def test_legalize_shape_change():
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_view_shape_from_unknown():
+def test_lower_runtime_builtin_view_shape_from_unknown():
     """R.memory.view does not require the input tensor to have a known shape"""
 
     @I.ir_module
@@ -529,11 +526,11 @@ def test_legalize_view_shape_from_unknown():
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_dtype_change():
+def test_lower_runtime_builtin_dtype_change():
     @I.ir_module
     class Before:
         @R.function
@@ -559,11 +556,11 @@ def test_legalize_dtype_change():
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_byte_offset():
+def test_lower_runtime_builtin_byte_offset():
     @I.ir_module
     class Before:
         @R.function
@@ -589,11 +586,11 @@ def test_legalize_byte_offset():
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_view_with_multiple_updated_fields():
+def test_lower_runtime_builtin_view_with_multiple_updated_fields():
     """R.memory.view may update more than one field in the view
 
     In this test case, a 4-kilobyte buffer is provided.  The first
@@ -650,7 +647,7 @@ def test_legalize_view_with_multiple_updated_fields():
             )
             return (B, C)
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 63f422d4cf..f9e632d348 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -185,7 +185,7 @@ def test_basic():
     tvm.ir.assert_structural_equal(mod, Expected)
     mod = relax.transform.LowerAllocTensor()(mod)
     mod = relax.transform.KillAfterLastUse()(mod)
-    mod = relax.transform.VMBuiltinLower()(mod)
+    mod = relax.transform.LowerRuntimeBuiltin()(mod)
     tvm.ir.assert_structural_equal(mod, ExpectedLowered)
 
 
@@ -1449,5 +1449,60 @@ def test_add():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_view():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main():
+            cls = Before
+            x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", 
runtime_device_index=0)
+            x1 = R.memory.view(x, [128], "float32", 0)
+            x2 = R.memory.ensure_zero_offset(x1)
+            y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", 
runtime_device_index=0)
+            cls.tir_exp(x2, y)
+            z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", 
runtime_device_index=0)
+            cls.tir_exp(y, z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main() -> R.Tensor((128,), dtype="float32"):
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([1024]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
+            )
+            x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32")
+            )
+            x1: R.Tensor((128,), dtype="float32") = R.memory.view(
+                x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
+            )
+            x2: R.Tensor((128,), dtype="float32") = 
R.memory.ensure_zero_offset(x1)
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([512]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
+            )
+            y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
+            )
+            cls.tir_exp(x2, y)
+            z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([128]), R.dtype("float32"), R.prim_value(0), 
R.str("global")
+            )
+            cls.tir_exp(y, z)
+            return z
+
+    after = relax.transform.StaticPlanBlockMemory()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_vm_builtin_lower.py 
b/tests/python/relax/test_vm_builtin_lower.py
index df28db4d46..984f9f958c 100644
--- a/tests/python/relax/test_vm_builtin_lower.py
+++ b/tests/python/relax/test_vm_builtin_lower.py
@@ -57,7 +57,7 @@ def test_vm_builtin_lower_mem_alloc_storage():
             gv0 = alloc
             return gv0
 
-    After = relax.transform.VMBuiltinLower()(Before)
+    After = relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
@@ -79,7 +79,7 @@ def test_vm_builtin_alloc_tensor_raises_error():
             return gv0
 
     with pytest.raises(tvm.TVMError):
-        relax.transform.VMBuiltinLower()(Before)
+        relax.transform.LowerRuntimeBuiltin()(Before)
 
 
 if __name__ == "__main__":

Reply via email to