This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 41684ee56778fa8772ea131a43097457259be3d8 Author: masahi <[email protected]> AuthorDate: Mon Feb 20 16:47:56 2023 +0900 [Unity][VM] Add per-op profiling support (#14053) Adds per-op profiling support to Relax VM, in a way similar to how Relay VM is instrumented via the common profiling infra in the runtime. Profiling over RPC is supported. Example output: ``` Name Duration (us) Percent Device Count Argument Shapes conv2d1 705,779.00 51.22 hexagon0 1 float32[1, 64, 56, 56], float32[1, 64, 54, 54] conv2d 669,589.00 48.60 hexagon0 1 float32[1, 64, 56, 56], float32[1, 64, 56, 56] relu 683.00 0.05 hexagon0 1 float32[1, 64, 56, 56], float32[1, 64, 56, 56] relu1 679.00 0.05 hexagon0 1 float32[1, 64, 54, 54], float32[1, 64, 54, 54] vm.builtin.check_tensor_info 28.00 0.00 hexagon0 1 float32[1, 64, 56, 56] vm.builtin.match_shape 25.00 0.00 hexagon0 1 float32[1, 64, 56, 56] ---------- Sum 1,376,783.00 99.93 6 Total 0.00 cpu0 1 Total 1,377,809.00 hexagon0 1 Configuration ------------- Number of threads: 4 Executor: VM ``` The original PR: https://github.com/tlc-pack/relax/pull/422 --- include/tvm/runtime/relax_vm/vm.h | 5 ++ python/tvm/relax/vm.py | 33 +++++++-- src/runtime/relax_vm/executable.cc | 7 ++ src/runtime/relax_vm/vm.cc | 129 +++++++++++++++++++++++++++++++- tests/python/relax/test_vm_profiler.py | 130 +++++++++++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 9 deletions(-) diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index cfe3880904..d39de74f2d 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -120,6 +120,11 @@ class VirtualMachine : public runtime::ModuleNode { * \return Created VM */ static ObjectPtr<VirtualMachine> Create(); + /*! + * \brief Create an instance of VM with the profiling feature enabled. + * \return Created VM + */ + static ObjectPtr<VirtualMachine> CreateProfiler(); /*! 
* \brief Helper function for vm closure functions to get the context ptr * \param arg The argument value. diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 2cf1250690..0594d86f2a 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -25,6 +25,7 @@ from tvm import relax from tvm.ir.module import IRModule from tvm.runtime import Device, Module, PackedFunc, container from tvm.runtime.object import Object +from tvm.runtime.profiling import Report from tvm.tir.function import PrimFunc from . import _ffi_api from ..rpc.base import RPC_SESS_MASK @@ -63,6 +64,7 @@ class VirtualMachine(object): exec: Union[Executable, Module], device: Union[Device, List[Device]], memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, + profile: bool = False, ) -> None: """ Construct a VirtualMachine wrapper object. @@ -82,12 +84,12 @@ class VirtualMachine(object): allocator type. If memory_cfg is a dict, each device uses the allocator type specified in the dict, or pooled allocator if not specified in the dict. + + profile : Optional[bool] + Whether or not to enable profiling. """ - self.module = ( - exec.mod["vm_load_executable"]() - if isinstance(exec, Executable) - else exec["vm_load_executable"]() - ) + load_exec = "vm_profiler_load_executable" if profile else "vm_load_executable" + self.module = exec.mod[load_exec]() if isinstance(exec, Executable) else exec[load_exec]() self._invoke_closure = self.module["invoke_closure"] self._save_function = self.module["save_function"] self._set_input = self.module["set_input"] @@ -449,6 +451,27 @@ class VirtualMachine(object): f_preproc=f_preproc, ) + def profile(self, func_name: str, *args): + """Profile a function call. + Parameters + ---------- + func_name : str + The name of the function. + args: List of NDArray or other objects supported by PackedFunc. + The arguments to the function. 
+ Returns + ------- + report: tvm.runtime.profiling.Report + The formatted profiling result, showing per-op timing measurements. + """ + cargs: List[Any] = [] + + for arg in args: + self._convert(arg, cargs) + + report_json = self.module["profile"](func_name, *cargs) + return Report.from_json(report_json) + def _vmcodegen( builder: "relax.ExecBuilder", diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index b7915d7978..2090a3b254 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -67,6 +67,13 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje vm->LoadExecutable(GetObjectPtr<Executable>(this)); *rv = Module(vm); }); + } else if (name == "vm_profiler_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ObjectPtr<VirtualMachine> vm = VirtualMachine::CreateProfiler(); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr<Executable>(this)); + *rv = Module(vm); + }); } return nullptr; } diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 3cf65faaa8..3b952c1ff5 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -23,8 +23,11 @@ #include <tvm/runtime/container/adt.h> #include <tvm/runtime/packed_func.h> +#include <tvm/runtime/profiling.h> #include <tvm/runtime/relax_vm/vm.h> +#include <optional> + namespace tvm { namespace runtime { namespace relax_vm { @@ -177,7 +180,7 @@ class VirtualMachineImpl : public VirtualMachine { void Init(const std::vector<Device>& devices, const std::vector<AllocatorType>& alloc_types) final; - PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override; VMClosure GetClosure(const String& func_name) final; @@ -315,11 +318,29 @@ class VirtualMachineImpl : public VirtualMachine { * \param curr_frame The 
current frame. * \param inst The call instruction. */ - inline void RunInstrCall(VMFrame* curr_frame, Instruction inst); + virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst); /*! \brief Run VM dispatch loop. */ void RunLoop(); + /*! + * \brief Retrieve the name of the function identified by the given index. + * \param idx The index into the VM executable function table. + * \return The name of the function. + */ + const std::string& GetFuncName(int idx) { return exec_->func_table[idx].name; } + + /*! + * \brief Retrieve the inputs for a function. + * \param func_name The name of the function. + * \return The function inputs. + */ + const std::vector<RegType>& GetInputsFor(const std::string& func_name) { + return inputs_[func_name]; + } + + void ClearInputsFor(const std::string& func_name) { inputs_.erase(func_name); } + private: //-------------------------------------------------------- // Internal states for execution. @@ -519,7 +540,7 @@ void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int offse int index = i - offset; func_args[index] = ConvertArgToDevice(args[i], devices[0]); } - inputs_.emplace(func_name, func_args); + inputs_[func_name] = func_args; } else { LOG(FATAL) << "ValueError: Unknown function: " << func_name; } @@ -706,7 +727,7 @@ void VirtualMachineImpl::InitFuncPool() { } void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { - DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << exec_->func_table[instr.func_idx].name; + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx); // Use the call arg stack from the current frame to increase reuse // and avoid re-allocation @@ -806,6 +827,106 @@ void VirtualMachineImpl::RunLoop() { ObjectPtr<VirtualMachine> VirtualMachine::Create() { return make_object<VirtualMachineImpl>(); } +/*! + * \brief An extension of VirtualMachineImpl to support per-op profiling + * It overrides RunInstrCall to add instrumentations around it. 
+ */ +class VirtualMachineProfiler : public VirtualMachineImpl { + public: + PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override { + if (name == "profile") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string f_name = args[0]; + VMClosure clo = this->GetClosure(f_name); + + std::vector<Device> devices; + for (auto dev : this->devices) { + if (dev.device_type > 0) { + devices.push_back(dev); + } + } + + prof_ = profiling::Profiler(devices, {}, {{String("Executor"), String("VM")}}); + + auto inputs = GetInputsFor(f_name); + + bool clear_inputs = false; + if (inputs.size() == 0) { + ICHECK(args.num_args > 1) << "No input is provided"; + TVMArgs f_args(args.values + 1, args.type_codes + 1, args.num_args - 1); + SetInput(f_name, args, 1); + inputs = GetInputsFor(f_name); + clear_inputs = true; + } else { + ICHECK_EQ(args.num_args, 1) << "Inputs are already provided by set_input."; + } + + // warmup + this->InvokeClosureInternal(clo, inputs); + + prof_->Start(); + this->InvokeClosureInternal(clo, inputs); + prof_->Stop(); + + // Return the report as json, since profiling::Report object is not supported by RPC + std::string report_json = prof_->Report()->AsJSON(); + *rv = report_json; + + prof_ = std::nullopt; // releases hardware counters + if (clear_inputs) { + // SetInput modifies the internal states of VM. Undo the change after profiling. 
+ ClearInputsFor(f_name); + } + }); + } else { + return VirtualMachineImpl::GetFunction(name, sptr_to_self); + } + } + + protected: + void RunInstrCall(VMFrame* curr_frame, Instruction inst) override { + bool profiling = false; + if (prof_ && prof_->IsRunning()) { + auto f_name = GetFuncName(inst.func_idx); + std::optional<Device> dev; + std::vector<NDArray> arrs; + for (Index i = 0; i < inst.num_args; ++i) { + Instruction::Arg arg = inst.args[i]; + if (arg.kind() == Instruction::ArgKind::kRegister) { + auto reg = ReadRegister(curr_frame, arg.value()); + if (reg.type_code() == kTVMNDArrayHandle) { + NDArray arr = reg; + dev = arr->device; + arrs.push_back(arr); + } + } + } + + std::unordered_map<std::string, ObjectRef> metrics; + metrics["Argument Shapes"] = profiling::ShapeString(arrs); + + // If a suitable device is found, enable profiling. + if (dev) { + profiling = true; + prof_->StartCall(f_name, *dev, metrics); + } + } + + VirtualMachineImpl::RunInstrCall(curr_frame, inst); + + if (profiling) { + prof_->StopCall(); + } + } + + private: + std::optional<profiling::Profiler> prof_; +}; + +ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() { + return make_object<VirtualMachineProfiler>(); +} + } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/test_vm_profiler.py b/tests/python/relax/test_vm_profiler.py new file mode 100644 index 0000000000..90737cc9c9 --- /dev/null +++ b/tests/python/relax/test_vm_profiler.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, rpc +from tvm.contrib import utils +from tvm.relax.testing import nn +from tvm.script import relax as R + + +def get_exec(data_shape): + builder = relax.BlockBuilder() + weight1_np = np.random.randn(64, 64).astype("float32") + weight2_np = np.random.randn(64, 64).astype("float32") + + with builder.function("main"): + model = nn.Sequential( + nn.Linear(data_shape[1], weight1_np.shape[0], bias=False), + nn.ReLU(), + nn.Linear(weight2_np.shape[0], weight2_np.shape[1], bias=False), + nn.ReLU(), + ) + data = nn.Placeholder(data_shape, name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + mod = builder.get() + + params = {"linear_weight": weight1_np, "linear_weight1": weight2_np} + mod = relax.transform.BindParams("main", params)(mod) + + target = "llvm" + return relax.vm.build(mod, target) + + +def test_conv2d_cpu(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + + vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True) + report = vm.profile("main", tvm.nd.array(data_np)) + print(report) + + assert "Duration" in str(report) + assert "matmul" in str(report) + + +def with_rpc(ex, f, data_np): + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + ex.mod.export_library(path) + + server = rpc.Server("127.0.0.1") + remote = rpc.connect(server.host, server.port, session_timeout=10) + + remote.upload(path) + rexec = remote.load_module("vm_library.so") + + device = 
remote.cpu() + + vm = relax.vm.VirtualMachine(exec=rexec, device=device, profile=True) + data = tvm.nd.array(data_np, device) + + f(vm, data) + + +def test_rpc(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + + def callback(vm, data): + vm.profile("main", data) + + vm.set_input("main", data) + report = vm.profile("main") + + assert "matmul" in str(report) + print(report) + + with_rpc(ex, callback, data_np) + + +def test_tuple(): + @tvm.script.ir_module + class NestedTuple: + @R.function + def main( + x: R.Tensor((16,), "float32") + ) -> R.Tuple( + R.Tuple( + R.Tensor((16,), "float32"), + R.Tuple( + R.Tensor((16,), "float32"), + ), + ), + R.Tensor((16,), "float32"), + ): + return ((x, (x,)), x) + + target = "llvm" + ex = relax.vm.build(NestedTuple, target) + + data_np = np.random.randn(16).astype("float32") + + def callback(vm, data): + report = vm.profile("main", data) + assert "vm.builtin.make_tuple" in str(report) + + with_rpc(ex, callback, data_np) + + +if __name__ == "__main__": + tvm.testing.main()
