This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 05e2bc3340 [Relax] Implement R.ensure_zero_offset and update memory
planning for R.view (#17145)
05e2bc3340 is described below
commit 05e2bc3340d1c0ca505e8a66bee29ffd5d294379
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Aug 6 07:13:49 2024 -0700
[Relax] Implement R.ensure_zero_offset and update memory planning for
R.view (#17145)
Previously, `R.view` was legalized to extern call to
`runtime.TVMArrayCreateView` during `LegalizeOps`. This call to extern
func can't be properly handled by `StaticBlockPlanMemory` because it
assumes the extern func does not retain the input buffer. Extern func
returning a view of the input would break the ref count of the
buffer. This PR defers the legalization of `R.view` so that it can be
explicitly handled by memory planning.
A new op `R.ensure_zero_offset` is added as discussed in #16955
---
include/tvm/relax/backend.h | 2 +-
include/tvm/relax/op_attr_types.h | 9 ++++
include/tvm/runtime/device_api.h | 5 ++
python/tvm/relax/op/memory/__init__.py | 2 +-
python/tvm/relax/op/memory/view.py | 17 +++++++
python/tvm/relax/pipeline.py | 2 +-
python/tvm/relax/transform/__init__.py | 9 ++--
python/tvm/relax/transform/transform.py | 17 ++++++-
...m_builtin_lower.cc => lower_runtime_builtin.cc} | 26 ++++++----
src/relax/op/memory/view.cc | 35 +++++++++++--
src/relax/op/memory/view.h | 3 ++
src/relax/transform/static_plan_block_memory.cc | 13 +++--
src/runtime/cpu_device_api.cc | 2 +
src/runtime/cuda/cuda_device_api.cc | 2 +
src/runtime/relax_vm/builtin.cc | 19 ++++++++
tests/python/relax/test_op_view.py | 31 ++++++------
.../test_transform_static_plan_block_memory.py | 57 +++++++++++++++++++++-
tests/python/relax/test_vm_builtin_lower.py | 4 +-
18 files changed, 211 insertions(+), 44 deletions(-)
diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
index 2fb11f5a6f..e7d13c47b2 100644
--- a/include/tvm/relax/backend.h
+++ b/include/tvm/relax/backend.h
@@ -35,7 +35,7 @@ namespace transform {
*
* \return The Pass.
*/
-TVM_DLL Pass VMBuiltinLower();
+TVM_DLL Pass LowerRuntimeBuiltin();
/*!
* \brief Lower the shape expression in relax to VM shape heap and TIR
functions.
diff --git a/include/tvm/relax/op_attr_types.h
b/include/tvm/relax/op_attr_types.h
index b44c4582d8..291bee597c 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -79,6 +79,15 @@ using FNormalize = runtime::TypedPackedFunc<Expr(const
BlockBuilder& bb, Call ca
*/
using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const
Call& call)>;
+/*! \brief The function type of a function to lower the runtime builtin.
+ *
+ * A builtin function may be lowered to its runtime form in
`LowerRuntimeBuiltin`.
+ *
+ * \param bb The BlockBuilder context.
+ * \param call The call to be lowered.
+ */
+using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb,
const Call& call)>;
+
/*!
* \brief Gradient for a specific op.
*
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 14b2b84b0d..c33606d98e 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -240,6 +240,11 @@ class TVM_DLL DeviceAPI {
return device_type != kDLCPU && device_type != kDLMicroDev;
}
+ /*!
+ * \brief Whether pointer arithmetics on a device owned pointer may be
performed on the host.
+ */
+ virtual bool SupportsDevicePointerArithmeticsOnHost() { return false; }
+
protected:
/*!
* \brief copy data from one place to another
diff --git a/python/tvm/relax/op/memory/__init__.py
b/python/tvm/relax/op/memory/__init__.py
index 422c5d2e1f..1191550085 100644
--- a/python/tvm/relax/op/memory/__init__.py
+++ b/python/tvm/relax/op/memory/__init__.py
@@ -17,4 +17,4 @@
"""Relax memory primitives."""
from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
-from .view import view
+from .view import view, ensure_zero_offset
diff --git a/python/tvm/relax/op/memory/view.py
b/python/tvm/relax/op/memory/view.py
index 0c3d8a03b2..95adc78209 100644
--- a/python/tvm/relax/op/memory/view.py
+++ b/python/tvm/relax/op/memory/view.py
@@ -92,3 +92,20 @@ def view(
relative_byte_offset = _normalize(relative_byte_offset, PrimValue)
return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type:
ignore
+
+
+def ensure_zero_offset(data: Expr) -> Expr:
+ """
+ Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input tensor
+
+    Returns
+ -------
+ result : relax.Expr
+ The tensor with elem_offset == 0
+ """
+ return _ffi_api.ensure_zero_offset(data) # type: ignore
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index d068f800d0..38242ff4d2 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -92,7 +92,7 @@ def default_build_pipeline():
transform.RewriteCUDAGraph(),
transform.LowerAllocTensor(),
transform.KillAfterLastUse(),
- transform.VMBuiltinLower(),
+ transform.LowerRuntimeBuiltin(),
transform.ComputePrimValue(),
transform.VMShapeLower(),
transform.AttachGlobalSymbol(),
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 5789e2fcf2..1ce864651c 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -55,6 +55,7 @@ from .transform import (
LegalizeOps,
LiftTransformParams,
LowerAllocTensor,
+ LowerRuntimeBuiltin,
MergeCompositeFunctions,
MetaScheduleApplyDatabase,
MetaScheduleTuneIRMod,
@@ -64,8 +65,8 @@ from .transform import (
PatternCheckContext,
RealizeVDevice,
RemovePurityChecking,
- RemoveUnusedParameters,
RemoveUnusedOutputs,
+ RemoveUnusedParameters,
ReorderPermuteDimsAfterConcat,
ReorderTakeAfterMatmul,
RewriteCUDAGraph,
@@ -84,14 +85,14 @@ from .transform import (
function_pass,
)
+from .attach_external_modules import AttachExternModules
+from .fast_math import FastMathTransform
+from .fuse_transpose_matmul import FuseTransposeMatmul
from .ipc_allreduce_rewrite import IPCAllReduceRewrite
from .lazy_transform_params import LazyTransformParams
from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
-from .fast_math import FastMathTransform
-from .fuse_transpose_matmul import FuseTransposeMatmul
-from .attach_external_modules import AttachExternModules
# Import to register the legalization functions.
from . import legalize_ops, tuning_api
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 3528b4429e..2546284625 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,6 +19,7 @@
import functools
import inspect
import types
+import warnings
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple,
Union
import numpy as np # type: ignore
@@ -586,6 +587,16 @@ def ComputePrimValue() -> tvm.ir.transform.Pass:
return _ffi_api.ComputePrimValue() # type: ignore
+def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass:
+ """Lowering generic intrinsic to VM intrinsics.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.LowerRuntimeBuiltin() # type: ignore
+
+
def VMBuiltinLower() -> tvm.ir.transform.Pass:
"""Lowering generic intrinsic to VM intrinsics.
@@ -593,7 +604,11 @@ def VMBuiltinLower() -> tvm.ir.transform.Pass:
-------
ret: tvm.ir.transform.Pass
"""
- return _ffi_api.VMBuiltinLower() # type: ignore
+ warnings.warn(
+ "tvm.relax.transform.VMBuiltinLower has been renamed to
'LowerRuntimeBuiltin'. "
+ "This wrapper is for backwards compatibility, and will be removed in a
later update."
+ )
+ return _ffi_api.LowerRuntimeBuiltin() # type: ignore
def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc
b/src/relax/backend/vm/lower_runtime_builtin.cc
similarity index 90%
rename from src/relax/backend/vm/vm_builtin_lower.cc
rename to src/relax/backend/vm/lower_runtime_builtin.cc
index 887998d004..a3867ae924 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -17,13 +17,14 @@
* under the License.
*/
/*!
- * \file src/relax/backend/vm/vm_builtin_lower.cc
+ * \file src/relax/backend/vm/lower_runtime_builtin.cc
* \brief Lowers most builtin functions and packed calls.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/op.h>
@@ -33,11 +34,12 @@ namespace relax {
// This pass lowers most ops to VM specific builtins.
// TODO(relax-team): revisit after PrimValue.
-class VMBuiltinLowerMutator : public ExprMutator {
+class LowerRuntimeBuiltinMutator : public ExprMutator {
public:
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call_node) final {
+ static const auto& lower_builtin_fmap =
Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
// post-order mutation
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
@@ -64,9 +66,13 @@ class VMBuiltinLowerMutator : public ExprMutator {
return MakeMemAllocTensor(call);
} else if (call->op == mem_kill_storage_op_ || call->op ==
mem_kill_tensor_op_) {
return MakeMemKillObject(call);
- } else {
- return call;
+ } else if (const auto* op_node = call->op.as<OpNode>()) {
+ Op op = GetRef<Op>(op_node);
+ if (lower_builtin_fmap.count(op)) {
+ return lower_builtin_fmap[op](builder_, call);
+ }
}
+ return call;
}
Expr MakeMemAllocStorage(const Call& call) {
@@ -210,17 +216,19 @@ class VMBuiltinLowerMutator : public ExprMutator {
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};
-Expr VMBuiltinLower(const Expr& e) { return
VMBuiltinLowerMutator().VisitExpr(e); }
+Expr LowerRuntimeBuiltin(const Expr& e) { return
LowerRuntimeBuiltinMutator().VisitExpr(e); }
namespace transform {
-Pass VMBuiltinLower() {
+Pass LowerRuntimeBuiltin() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
- [=](Function f, IRModule m, PassContext pc) { return
Downcast<Function>(VMBuiltinLower(f)); };
- return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(LowerRuntimeBuiltin(f));
+ };
+ return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
}
-TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
+TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin);
} // namespace transform
} // namespace relax
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index e7634c7edf..21a72f6200 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -291,7 +291,7 @@ StructInfo InferStructInfoView(const Call& call, const
BlockBuilder& ctx) {
TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);
-Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
+Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
Expr data = call->args[0];
Expr shape = call->args[1];
Expr dtype = call->args[2];
@@ -352,8 +352,37 @@ TVM_REGISTER_OP("relax.memory.view")
"The view's byte offset, relative to the input tensor's byte
offset.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
- .set_attr<FLegalize>("FLegalize", LegalizeView)
- .set_attr<Bool>("FPurity", Bool(true));
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);
+
+Expr ensure_zero_offset(const Expr& x) {
+ static const Op& op = Op::Get("relax.memory.ensure_zero_offset");
+ return Call(op, {x});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset);
+
+StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const
BlockBuilder& ctx) {
+ if (call->args.size() != 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Operator " << call->op << " should receive 1
argument, "
+ << "but received " << call->args);
+ }
+ return GetStructInfo(call->args[0]);
+}
+
+Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) {
+ const ExternFunc
builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"};
+ return Call(builtin_ensure_zero_offset_, call->args, Attrs(),
{GetStructInfo(call)});
+}
+
+TVM_REGISTER_OP("relax.memory.ensure_zero_offset")
+ .set_num_inputs(1)
+ .add_argument("x", "Tensor", "The input tensor.")
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoEnsureZeroOffset)
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinEnsureZeroOffset);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h
index bc8002fa5b..77ec7e9833 100644
--- a/src/relax/op/memory/view.h
+++ b/src/relax/op/memory/view.h
@@ -32,6 +32,9 @@ namespace relax {
/*! \brief View a tensor with different properties. */
Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr>
relative_byte_offset);
+/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if
necessary. */
+Expr ensure_zero_offset(const Expr& x);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/static_plan_block_memory.cc
b/src/relax/transform/static_plan_block_memory.cc
index 2b16d86509..74200526b6 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -286,8 +286,13 @@ class TokenAllocator1D {
std::vector<StorageToken> full_pool_;
};
-/*! \brief Check if the input op is "relax.reshape". */
-bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
+/*! \brief Check if the input op is a memory op that may return the same
buffer. */
+bool IsInplaceMemoryOp(const Expr& op) {
+ static const Op& reshape_op = Op::Get("relax.reshape");
+ static const Op& view_op = Op::Get("relax.memory.view");
+ static const Op& ensure_zero_offset_op =
Op::Get("relax.memory.ensure_zero_offset");
+ return op.same_as(reshape_op) || op.same_as(view_op) ||
op.same_as(ensure_zero_offset_op);
+}
/*! \brief The base class for the storage allocation visitor. */
class StorageAllocatorBaseVisitor : public ExprVisitor {
@@ -498,7 +503,7 @@ class StorageAllocatorInit : public
StorageAllocatorBaseVisitor {
// Create a storage token for builtin alloc_tensor.
this->CreateToken(call);
return;
- } else if (IsReshape(call->op)) {
+ } else if (IsInplaceMemoryOp(call->op)) {
// Reuse the input's token for builtin reshape.
SetTokens(call, GetTokens(call->args[0]));
return;
@@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor
{
block_tokens.push_back(new_token.get());
}
return;
- } else if (IsReshape(call->op)) {
+ } else if (IsInplaceMemoryOp(call->op)) {
Tokens tokens = GetTokens(call->args[0]);
ICHECK(!tokens.IsNested());
if (tokens.IsLeaf()) {
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index 774335f566..ccd726a6ec 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -73,6 +73,8 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;
+ bool SupportsDevicePointerArithmeticsOnHost() final { return true; }
+
static CPUDeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
diff --git a/src/runtime/cuda/cuda_device_api.cc
b/src/runtime/cuda/cuda_device_api.cc
index 66357a1915..33908d750d 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -262,6 +262,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data);
}
+ bool SupportsDevicePointerArithmeticsOnHost() final { return true; }
+
static CUDADeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index af1cf9d203..9fe6fba80f 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -551,6 +551,25 @@
TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
return ShapeTuple(out_shape);
});
+TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray
data) {
+ if (data->byte_offset == 0) {
+ return data;
+ }
+ auto* device_api = DeviceAPI::Get(data->device);
+ if (device_api->SupportsDevicePointerArithmeticsOnHost() &&
+ data->byte_offset % tvm::runtime::kAllocAlignment == 0) {
+ DLManagedTensor* dl_tensor = data.ToDLPack();
+ dl_tensor->dl_tensor.data =
+ reinterpret_cast<char*>(dl_tensor->dl_tensor.data) +
dl_tensor->dl_tensor.byte_offset;
+ dl_tensor->dl_tensor.byte_offset = 0;
+ return NDArray::FromDLPack(dl_tensor);
+ } else {
+ auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device);
+ new_array.CopyFrom(data);
+ return new_array;
+ }
+});
+
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/relax/test_op_view.py
b/tests/python/relax/test_op_view.py
index 2433821c2a..0900e1be30 100644
--- a/tests/python/relax/test_op_view.py
+++ b/tests/python/relax/test_op_view.py
@@ -452,7 +452,9 @@ def test_applying_unknown_relative_byte_offset_is_legal():
tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)
-def test_legalize_without_any_changes_is_no_op():
+def test_legalize_is_no_op():
+ """R.memory.view is not legalized until LowerRuntimeBuiltin"""
+
@I.ir_module
class Before:
@R.function
@@ -460,18 +462,13 @@ def test_legalize_without_any_changes_is_no_op():
B = R.memory.view(A)
return B
- @I.ir_module
- class Expected:
- @R.function
- def main(A: R.Tensor([4096], "float32")):
- B = A
- return B
+ Expected = Before
After = tvm.relax.transform.LegalizeOps()(Before)
tvm.ir.assert_structural_equal(Expected, After)
-def test_legalize_shape_change():
+def test_lower_runtime_builtin_shape_change():
@I.ir_module
class Before:
@R.function
@@ -497,11 +494,11 @@ def test_legalize_shape_change():
)
return B
- After = tvm.relax.transform.LegalizeOps()(Before)
+ After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
tvm.ir.assert_structural_equal(Expected, After)
-def test_legalize_view_shape_from_unknown():
+def test_lower_runtime_builtin_view_shape_from_unknown():
"""R.memory.view does not require the input tensor to have a known shape"""
@I.ir_module
@@ -529,11 +526,11 @@ def test_legalize_view_shape_from_unknown():
)
return B
- After = tvm.relax.transform.LegalizeOps()(Before)
+ After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
tvm.ir.assert_structural_equal(Expected, After)
-def test_legalize_dtype_change():
+def test_lower_runtime_builtin_dtype_change():
@I.ir_module
class Before:
@R.function
@@ -559,11 +556,11 @@ def test_legalize_dtype_change():
)
return B
- After = tvm.relax.transform.LegalizeOps()(Before)
+ After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
tvm.ir.assert_structural_equal(Expected, After)
-def test_legalize_byte_offset():
+def test_lower_runtime_builtin_byte_offset():
@I.ir_module
class Before:
@R.function
@@ -589,11 +586,11 @@ def test_legalize_byte_offset():
)
return B
- After = tvm.relax.transform.LegalizeOps()(Before)
+ After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
tvm.ir.assert_structural_equal(Expected, After)
-def test_legalize_view_with_multiple_updated_fields():
+def test_lower_runtime_builtin_view_with_multiple_updated_fields():
"""R.memory.view may update more than one field in the view
In this test case, a 4-kilobyte buffer is provided. The first
@@ -650,7 +647,7 @@ def test_legalize_view_with_multiple_updated_fields():
)
return (B, C)
- After = tvm.relax.transform.LegalizeOps()(Before)
+ After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
tvm.ir.assert_structural_equal(Expected, After)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 63f422d4cf..f9e632d348 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -185,7 +185,7 @@ def test_basic():
tvm.ir.assert_structural_equal(mod, Expected)
mod = relax.transform.LowerAllocTensor()(mod)
mod = relax.transform.KillAfterLastUse()(mod)
- mod = relax.transform.VMBuiltinLower()(mod)
+ mod = relax.transform.LowerRuntimeBuiltin()(mod)
tvm.ir.assert_structural_equal(mod, ExpectedLowered)
@@ -1449,5 +1449,60 @@ def test_add():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_view():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.evaluate(0)
+
+ @R.function
+ def main():
+ cls = Before
+ x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32",
runtime_device_index=0)
+ x1 = R.memory.view(x, [128], "float32", 0)
+ x2 = R.memory.ensure_zero_offset(x1)
+ y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32",
runtime_device_index=0)
+ cls.tir_exp(x2, y)
+ z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32",
runtime_device_index=0)
+ cls.tir_exp(y, z)
+ return z
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.evaluate(0)
+
+ @R.function
+ def main() -> R.Tensor((128,), dtype="float32"):
+ cls = Expected
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([1024]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor(
+ storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32")
+ )
+ x1: R.Tensor((128,), dtype="float32") = R.memory.view(
+ x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
+ )
+ x2: R.Tensor((128,), dtype="float32") =
R.memory.ensure_zero_offset(x1)
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([512]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
+ storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
+ )
+ cls.tir_exp(x2, y)
+ z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([128]), R.dtype("float32"), R.prim_value(0),
R.str("global")
+ )
+ cls.tir_exp(y, z)
+ return z
+
+ after = relax.transform.StaticPlanBlockMemory()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_vm_builtin_lower.py
b/tests/python/relax/test_vm_builtin_lower.py
index df28db4d46..984f9f958c 100644
--- a/tests/python/relax/test_vm_builtin_lower.py
+++ b/tests/python/relax/test_vm_builtin_lower.py
@@ -57,7 +57,7 @@ def test_vm_builtin_lower_mem_alloc_storage():
gv0 = alloc
return gv0
- After = relax.transform.VMBuiltinLower()(Before)
+ After = relax.transform.LowerRuntimeBuiltin()(Before)
tvm.ir.assert_structural_equal(Expected, After)
@@ -79,7 +79,7 @@ def test_vm_builtin_alloc_tensor_raises_error():
return gv0
with pytest.raises(tvm.TVMError):
- relax.transform.VMBuiltinLower()(Before)
+ relax.transform.LowerRuntimeBuiltin()(Before)
if __name__ == "__main__":