This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eb4175bd3d [VM] Recycle VMFrame (#16822)
eb4175bd3d is described below
commit eb4175bd3ddc99a5d902eed30476127a0abdc1dc
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Mar 30 16:30:51 2024 -0400
[VM] Recycle VMFrame (#16822)
This PR recycles VMFrame objects in the VM (via a free list), which can help
a bit when a function involves large frames.
---
src/runtime/relax_vm/vm.cc | 35 +++++++++++++++++++++++++++---
src/support/ffi_testing.cc | 54 ++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 86 insertions(+), 3 deletions(-)
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index d7f943d5f4..618e68c4fd 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -177,6 +177,20 @@ struct VMFrame {
+ // Fresh frame with register_file_size registers; return_pc is what FrameGuard
+ // restores into the VM pc when this frame is popped.
VMFrame(Index pc, Index register_file_size)
: return_pc(pc), register_file(register_file_size),
caller_return_register(0) {}
+
+ /*!
+ * \brief Drop everything the frame currently holds so it can be parked for reuse.
+ * \note The register file keeps its size; every register is reset to nullptr.
+ */
+ void Clear() {
+ caller_return_register = 0;
+ call_arg_values.clear();
+ call_arg_tcodes.clear();
+ for (size_t i = 0; i < register_file.size(); ++i) {
+ register_file[i] = nullptr;
+ }
+ }
+
+ /*!
+ * \brief Re-target a previously cleared frame at a new call.
+ * \param ret_pc The pc to return to when this frame is popped.
+ * \param num_regs The number of registers the new callee needs.
+ */
+ void ResetForRecycle(Index ret_pc, Index num_regs) {
+ return_pc = ret_pc;
+ register_file.resize(num_regs);
+ }
};
class VirtualMachineImpl : public VirtualMachine {
@@ -322,6 +336,8 @@ class VirtualMachineImpl : public VirtualMachine {
~FrameGuard() {
ICHECK_GT(vm->frames_.size(), 0);
vm->pc_ = vm->frames_.back()->return_pc;
+ // Recycle instead of destroying: reset the frame's registers and call-argument
+ // buffers, then park it on the free list where PushFrame can pick it up again.
+ // NOTE(review): emplace_back may allocate; destructors are implicitly noexcept,
+ // so an allocation failure here would end in std::terminate.
+ vm->frames_.back()->Clear();
+ vm->frame_free_list_.emplace_back(std::move(vm->frames_.back()));
vm->frames_.pop_back();
}
};
@@ -335,7 +351,15 @@ class VirtualMachineImpl : public VirtualMachine {
* \return A RAII wrapper that pops the frame when going out of scope.
*/
FrameGuard PushFrame(Index ret_pc, const VMFuncInfo& vm_func) {
- return FrameGuard(this, std::make_unique<VMFrame>(ret_pc,
vm_func.register_file_size));
+ // Nothing to recycle yet: allocate a brand-new frame.
+ if (frame_free_list_.empty()) {
+ return FrameGuard(this, std::make_unique<VMFrame>(ret_pc, vm_func.register_file_size));
+ }
+ // Otherwise reuse the most recently released frame; it was cleared when it
+ // was released, so only the return pc and register count need updating.
+ std::unique_ptr<VMFrame> recycled = std::move(frame_free_list_.back());
+ frame_free_list_.pop_back();
+ recycled->ResetForRecycle(ret_pc, vm_func.register_file_size);
+ return FrameGuard(this, std::move(recycled));
}
/*!
* \brief Write to a VM register.
@@ -343,7 +367,7 @@ class VirtualMachineImpl : public VirtualMachine {
* \param reg The register to write to.
* \param obj The object to write to.
*/
- void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) {
+ // NOTE(review): force-inlined, presumably because register writes sit on the
+ // interpreter's per-instruction path; the bounds check below is still kept.
+ TVM_ALWAYS_INLINE void WriteRegister(VMFrame* frame, RegName reg, const
RegType& obj) {
ICHECK_LT(reg, frame->register_file.size());
frame->register_file[reg] = obj;
}
@@ -353,7 +377,7 @@ class VirtualMachineImpl : public VirtualMachine {
* \param reg The register to read from.
* \return The value of the register.
*/
- RegType ReadRegister(VMFrame* frame, RegName reg) {
+ TVM_ALWAYS_INLINE RegType ReadRegister(VMFrame* frame, RegName reg) {
if (reg < Instruction::kBeginSpecialReg) {
return frame->register_file[reg];
}
@@ -425,6 +449,11 @@ class VirtualMachineImpl : public VirtualMachine {
* \note: Use unique ptr to avoid re-allocation and copy when frames_ get
resized.
*/
std::vector<std::unique_ptr<VMFrame>> frames_;
+ /*!
+ * \brief Frames released by FrameGuard (already cleared), kept so that
+ * PushFrame can reuse them instead of allocating a new frame per call.
+ */
+ std::vector<std::unique_ptr<VMFrame>> frame_free_list_;
+
/*! \brief The virtual machine PC. */
Index pc_{0};
/*! \brief The special return register. */
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index 75b5a2527f..aec57a1eb2 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -189,4 +189,58 @@
TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian
TVM_REGISTER_GLOBAL("testing.AcceptsVariant")
.set_body_typed([](Variant<String, Integer> arg) -> String { return
arg->GetTypeKey(); });
+/**
+ * Simple event logger that can be used for testing purposes.
+ *
+ * Record() appends an event name together with the time elapsed since the
+ * logger was constructed, in microseconds. One instance exists per thread
+ * (see ThreadLocal()). Reset() drops the recorded entries but does not restart
+ * the clock, so later timestamps stay relative to construction.
+ */
+class TestingEventLogger {
+ public:
+ /*! \brief A single recorded event. */
+ struct Entry {
+ /*! \brief Name of the event. */
+ String event;
+ /*! \brief Microseconds elapsed since the logger was constructed. */
+ double time_us;
+ };
+
+ TestingEventLogger() {
+ entries_.reserve(1024);
+ start_ = std::chrono::steady_clock::now();
+ }
+
+ /*!
+ * \brief Append an event stamped with the current elapsed time.
+ * \param event The name of the event.
+ */
+ void Record(String event) {
+ auto tend = std::chrono::steady_clock::now();
+ // Convert through a microsecond duration instead of assuming the clock ticks
+ // in nanoseconds: a chrono clock's tick period is implementation-defined.
+ double time_us = std::chrono::duration<double, std::micro>(tend - start_).count();
+ entries_.emplace_back(Entry{std::move(event), time_us});
+ }
+
+ /*! \brief Discard all recorded entries (the start time is kept). */
+ void Reset() { entries_.clear(); }
+
+ /*! \brief Write every recorded entry to LOG(INFO), one line per event. */
+ void Dump() const {
+ for (const Entry& e : entries_) {
+ LOG(INFO) << e.event << "\t" << e.time_us << " us";
+ }
+ }
+
+ /*! \brief The logger instance owned by the calling thread. */
+ static TestingEventLogger* ThreadLocal() {
+ thread_local TestingEventLogger inst;
+ return &inst;
+ }
+
+ private:
+ // steady_clock is monotonic. high_resolution_clock may be an alias of the
+ // adjustable system_clock (it is in libstdc++), which can make intervals jump.
+ std::chrono::steady_clock::time_point start_;
+ std::vector<Entry> entries_;
+};
+
+// Record an event on the calling thread's logger: under the caller-supplied name
+// when the first argument is a string, under the placeholder "X" otherwise.
+TVM_REGISTER_GLOBAL("testing.record_event").set_body([](TVMArgs args, TVMRetValue* rv) {
+ TestingEventLogger* logger = TestingEventLogger::ThreadLocal();
+ if (args.size() == 0 || args[0].type_code() != kTVMStr) {
+ logger->Record("X");
+ return;
+ }
+ logger->Record(args[0]);
+});
+
+// Drop every event recorded so far on the calling thread's logger.
+TVM_REGISTER_GLOBAL("testing.reset_events").set_body([](TVMArgs, TVMRetValue*) {
+ TestingEventLogger* logger = TestingEventLogger::ThreadLocal();
+ logger->Reset();
+});
+
+// Print every event recorded on the calling thread's logger via LOG(INFO).
+TVM_REGISTER_GLOBAL("testing.dump_events").set_body_typed([] {
+ TestingEventLogger* logger = TestingEventLogger::ThreadLocal();
+ logger->Dump();
+});
} // namespace tvm