This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 50d0128a2e [Unity][DEBUG] Add Instrument (#14302)
50d0128a2e is described below
commit 50d0128a2eca29212ccf2ca277f207e670aa69e8
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Mar 15 11:59:57 2023 -0400
[Unity][DEBUG] Add Instrument (#14302)
This PR adds an instrumentation option to the relax VM.
The instrument will be called before/after each call
instruction if specified.
We also include a testing utility that uses the
instrument. LibCompareVMInstrument leverages the instrument
to compare implementations on another backend.
Also updated a few places in web runtime to improve debugging.
---
include/tvm/runtime/relax_vm/vm.h | 30 +++++++
python/tvm/relax/__init__.py | 2 +-
python/tvm/relax/testing/lib_comparator.py | 128 +++++++++++++++++++++++++++++
python/tvm/runtime/relax_vm.py | 48 +++++++++++
src/runtime/relax_vm/vm.cc | 70 ++++++++++++++--
src/target/source/codegen_webgpu.cc | 26 +++++-
src/target/spirv/intrin_rule_spirv.cc | 5 ++
tests/python/relax/test_vm_instrument.py | 87 ++++++++++++++++++++
web/apps/browser/rpc_server.html | 9 +-
web/src/rpc_server.ts | 1 +
web/src/runtime.ts | 10 +++
web/src/webgpu.ts | 48 ++++++++++-
12 files changed, 445 insertions(+), 19 deletions(-)
diff --git a/include/tvm/runtime/relax_vm/vm.h
b/include/tvm/runtime/relax_vm/vm.h
index bd59106cc1..95a2080159 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -39,6 +39,16 @@ namespace tvm {
namespace runtime {
namespace relax_vm {
+/*!
+ * \brief Possible instrument actions.
+ */
+enum class VMInstrumentReturnKind : int {
+ /*! \brief Running as normal. */
+ kNoOp = 0,
+ /*! \brief Skip the following run, only valid in before. */
+ kSkipRun = 1,
+};
+
/*!
* \brief An object representing a vm closure.
*/
@@ -119,6 +129,26 @@ class VirtualMachine : public runtime::ModuleNode {
*/
virtual void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc,
TVMArgs args,
TVMRetValue* rv) = 0;
+ /*!
+ * \brief Set an instrumentation function.
+ *
+ * If instrument is present, the function will be called
+ * before/after each Call instruction.
+ *
+ * bool instrument(func, func_symbol, before_run, args...)
+ *
+ * - func: Union[VMClosure, PackedFunc], the function object.
+ * - func_symbol: string, the symbol of the function.
+ * - before_run: bool, whether it is before or after call.
+ * - ret_value: only valid after run; otherwise it is null.
+ * - args: the arguments being passed to call.
+ *
+ * instrument can return an int which corresponds to the action value.
+ * \sa VMInstrumentReturnKind
+ *
+ * \param instrument The instrument function.
+ */
+ virtual void SetInstrument(PackedFunc instrument) = 0;
/*!
* \brief Create a specific instance of VM.
* \return Created VM
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index edbd848bd5..f34a3169bb 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -17,7 +17,7 @@
# pylint: disable=invalid-name, wrong-import-position
"""The Relax IR namespace containing the IR, type, operator, builder, vm,
etc."""
from tvm.runtime import relax_vm as vm
-from tvm.runtime.relax_vm import VirtualMachine
+from tvm.runtime.relax_vm import VirtualMachine, VMInstrumentReturnKind
# Expr
from .expr import (
diff --git a/python/tvm/relax/testing/lib_comparator.py
b/python/tvm/relax/testing/lib_comparator.py
new file mode 100644
index 0000000000..a9cecc69dc
--- /dev/null
+++ b/python/tvm/relax/testing/lib_comparator.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Tools to compare libraries."""
+from typing import List, Tuple, Iterable, Union
+
+import tvm
+import tvm.testing
+
+
+class LibCompareVMInstrument:
+ """Instrument class to compare libs.
+
+ This class builds an instrument function that
+ pair-tests an existing compiled relax VM implementation
+ against an extra module, which can sit in another backend
+ but offers the same subset of compiled TIR functions.
+
+ The instrumentation enables us to automatically
+ check and compare each op being called in the pipeline
+ by looking up the same name in the provided mod and running tests.
+
+ Parameters
+ ----------
+ mod: runtime.Module
+ The module of interest to be validated.
+
+ device: runtime.Device
+ The device to run the target module on.
+
+ verbose: bool
+ Whether to print out messages.
+
+ rtol: float
+ rtol used in validation
+
+ atol: float
+ atol used in validation
+ """
+
+ def __init__(self, mod, device, verbose=True, rtol=1e-5, atol=1e-5):
+ self.mod = mod
+ self.device = device
+ self.verbose = verbose
+ self.counter = 0
+ self.rtol = rtol
+ self.atol = atol
+
+ def compare(
+ self,
+ name: str,
+ ref_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]],
+ new_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]],
+ ret_indices: Iterable[int],
+ ):
+ """Comparison function, can be overloaded.
+
+ Parameters
+ ----------
+ name: str
+ Name of the function.
+
+ ref_args:
+ The reference arguments.
+
+ new_args:
+ The args to be passed to the comparison function.
+
+ ret_indices:
+ List of indices to validate return values.
+ """
+ my_func = self.mod.get_function(name, query_imports=True)
+ if self.verbose:
+ print(f"[{self.counter}] Validating {name} ...")
+ my_func(*new_args)
+ for rindex in ret_indices:
+ tvm.testing.assert_allclose(
+ new_args[rindex].numpy(), ref_args[rindex].numpy(),
atol=self.atol, rtol=self.rtol
+ )
+ if self.verbose:
+ print(f"[{self.counter}] Validating {name}, passed.")
+ self.counter += 1
+
+ def skip_instrument(self, func, name, before_run, ret_val, *args):
+ return False
+
+ def __call__(self, func, name, before_run, ret_val, *args):
+ if before_run:
+ return
+ if name.startswith("vm.builtin."):
+ return
+ if any(not isinstance(x, tvm.nd.NDArray) for x in args):
+ return
+ try:
+ self.mod.get_function(name, query_imports=True)
+ except AttributeError:
+ if self.verbose:
+ print(f"Cannot find {name}, skip...")
+ return
+
+ if self.skip_instrument(func, name, before_run, ret_val, *args):
+ return
+
+ new_args = []
+ # not always true, true for most ops.
+ ret_indices = (len(args) - 1,)
+ for i, arg in enumerate(args):
+ arr = tvm.nd.empty(arg.shape, device=self.device)
+ # copy from cpu since we look at different device
+ if i not in ret_indices:
+ arr.copyfrom(arg.copyto(tvm.cpu()))
+ new_args.append(arr)
+
+ self.compare(name, args, new_args, ret_indices)
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index 9defcb7d80..c53882095d 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name, redefined-builtin, no-else-return,
consider-using-dict-items
"""The Relax virtual machine."""
from typing import Callable, List, Optional, Union, Dict, Tuple, Any
+from enum import IntEnum
import numpy as np # type: ignore
import tvm
@@ -28,6 +29,12 @@ from tvm.runtime.profiling import Report
from ..rpc.base import RPC_SESS_MASK
+class VMInstrumentReturnKind(IntEnum):
+ NO_OP = 0
+ # skip the following call, only valid in before
+ SKIP_RUN = 1
+
+
class VirtualMachine(object):
"""Relax VM runtime."""
@@ -85,6 +92,7 @@ class VirtualMachine(object):
self._get_output_arity = self.module["get_output_arity"]
self._get_function_arity = self.module["get_function_arity"]
self._get_function_param_name = self.module["get_function_param_name"]
+ self._set_instrument = self.module["set_instrument"]
self._setup_device(device, memory_cfg)
def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device,
str]]) -> None:
@@ -329,6 +337,46 @@ class VirtualMachine(object):
return get_output_rec(func_name)
+ def set_instrument(self, instrument: tvm.runtime.PackedFunc):
+ """Set an instrumentation function.
+
+ If instrument is present, the function will be called
+ before/after each Call instruction. The function has
+ the following signature:
+
+ .. code:: python
+
+ def instrument(
+ func: Union[VMClosure, PackedFunc],
+ func_symbol: str,
+ before_run: bool,
+ ret_value: any,
+ *args) -> bool:
+ pass
+
+ The instrument takes the following parameters:
+ - func: function object to be called.
+ - func_symbol: the symbol name of the function.
+ - before_run: whether it is before or after call.
+ - ret_value: the return value of the call, only valid after run.
+ - args: the arguments being passed to call.
+
+ The instrument function can return an integer,
+ which corresponds to the action to take for the
+ following run. See VMInstrumentReturnKind for
+ more details.
+
+ Parameters
+ ----------
+ instrument: tvm.runtime.PackedFunc
+ An instrumentation function that gets invoked at every VM call instruction.
+
+ See Also
+ --------
+ VMInstrumentReturnKind: the possible return values in VM.
+ """
+ self._set_instrument(instrument)
+
def time_evaluator(
self,
func_name,
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 9a3ce50bcc..6088833dc5 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -193,6 +193,8 @@ class VirtualMachineImpl : public VirtualMachine {
void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs
args,
TVMRetValue* rv) final;
+ void SetInstrument(PackedFunc instrument) final { this->instrument_ =
instrument; }
+
//--------------------------------------------------
// Additional support arguments functions for VM
//--------------------------------------------------
@@ -382,6 +384,8 @@ class VirtualMachineImpl : public VirtualMachine {
Index pc_{0};
/*! \brief The special return register. */
RegType return_value_;
+ /*! \brief The instrument function. */
+ PackedFunc instrument_ = nullptr;
};
void VirtualMachineImpl::LoadExecutable(ObjectPtr<Executable> exec) {
@@ -465,6 +469,21 @@ PackedFunc VirtualMachineImpl::GetFunction(const
std::string& name,
this->InvokeClosurePacked(clo, TVMArgs(args.values + 1, args.type_codes
+ 1, args.size() - 1),
rv);
});
+ } else if (name == "set_instrument") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ PackedFunc func;
+ if (args[0].type_code() != kTVMPackedFuncHandle) {
+ String func_name = args[0];
+ const PackedFunc* factory = Registry::Get(func_name);
+ ICHECK(factory != nullptr) << "Cannot find factory " << func_name;
+ TVMRetValue rv;
+ factory->CallPacked(TVMArgs(args.values + 1, args.type_codes + 1,
args.num_args - 1), &rv);
+ func = rv;
+ } else {
+ func = args[0];
+ }
+ this->SetInstrument(func);
+ });
} else if (name == "invoke_stateful") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
@@ -746,11 +765,11 @@ void VirtualMachineImpl::InitFuncPool() {
void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
DLOG(INFO) << "\n pc = " << pc_ << ", execute: " <<
GetFuncName(instr.func_idx);
-
+ int args_begin_offset = instrument_ != nullptr ? 4 : 0;
// Use the call arg stack from the current frame to increase reuse
// and avoid re-allocation
- curr_frame->call_arg_values.resize(instr.num_args);
- curr_frame->call_arg_tcodes.resize(instr.num_args);
+ curr_frame->call_arg_values.resize(args_begin_offset + instr.num_args);
+ curr_frame->call_arg_tcodes.resize(args_begin_offset + instr.num_args);
// NOTE: no changes and resize to those vector ref(otherwise can leads to
segfault)
// in the remainder part of the function.
@@ -760,22 +779,23 @@ void VirtualMachineImpl::RunInstrCall(VMFrame*
curr_frame, Instruction instr) {
runtime::TVMArgsSetter setter(values.data(), tcodes.data());
for (Index i = 0; i < instr.num_args; ++i) {
Instruction::Arg arg = instr.args[i];
+ int arg_index = args_begin_offset + i;
switch (arg.kind()) {
case Instruction::ArgKind::kRegister: {
- setter(i, ReadRegister(curr_frame, arg.value()));
+ setter(arg_index, ReadRegister(curr_frame, arg.value()));
break;
}
case Instruction::ArgKind::kImmediate: {
- setter(i, arg.value());
+ setter(arg_index, arg.value());
break;
}
case Instruction::ArgKind::kConstIdx: {
- setter(i, this->const_pool_[arg.value()]);
+ setter(arg_index, this->const_pool_[arg.value()]);
break;
}
case Instruction::ArgKind::kFuncIdx: {
ICHECK_LT(static_cast<size_t>(arg.value()), this->func_pool_.size());
- setter(i, this->func_pool_[arg.value()]);
+ setter(arg_index, this->func_pool_[arg.value()]);
break;
}
default: {
@@ -783,11 +803,43 @@ void VirtualMachineImpl::RunInstrCall(VMFrame*
curr_frame, Instruction instr) {
}
}
}
- TVMArgs args(values.data(), tcodes.data(), values.size());
+ TVMArgs args(values.data() + args_begin_offset, tcodes.data() +
args_begin_offset,
+ instr.num_args);
TVMRetValue ret;
ICHECK_LT(static_cast<size_t>(instr.func_idx), this->func_pool_.size());
- this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
+
+ if (instrument_ == nullptr) {
+ this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
+ } else {
+ // insert light-weight instrument callback
+ setter(0, func_pool_[instr.func_idx]);
+ setter(1, GetFuncName(instr.func_idx));
+ setter(2, true);
+ setter(3, nullptr);
+ TVMRetValue rv;
+ // store dtype to str since py callback cannot handle dtype atm.
+ std::vector<std::unique_ptr<std::string>> temp_dtype;
+ for (int i = 0; i < instr.num_args; ++i) {
+ if (tcodes[i + args_begin_offset] == kTVMDataType) {
+ std::string str_dtype = args[i];
+ temp_dtype.emplace_back(std::make_unique<std::string>(str_dtype));
+ setter(i + args_begin_offset, *temp_dtype.back());
+ }
+ }
+ int ret_kind = static_cast<int>(VMInstrumentReturnKind::kNoOp);
+ instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(),
values.size()), &rv);
+ if (rv.type_code() == kDLInt) {
+ ret_kind = rv;
+ }
+ if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) {
+ this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
+ setter(2, false);
+ setter(3, ret);
+ instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(),
values.size()), &rv);
+ }
+ }
+
// save the return value to the register
// saving to special register is a NOP
if (instr.dst < Instruction::kBeginSpecialReg) {
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index 8ba2b4a65e..ccac00a778 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -335,8 +335,30 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLIN
this->PrintExpr(EnforceU32(op->args[1]), os);
os << ')';
} else if (op->op.same_as(builtin::if_then_else())) {
- // WebGPU will insert clamping in buffer access so no need to check OOB.
- this->PrintExpr(Select(op->args[0], op->args[1], op->args[2]), os);
+ // conditional that skips eval if cond evals to false
+ std::string result = name_supply_->FreshName("condval");
+ std::string cond = PrintExpr(op->args[0]);
+ this->PrintIndent();
+ this->stream << "var " << result << " : ";
+ PrintType(op->dtype, this->stream);
+ this->stream << ";\n";
+ this->PrintIndent();
+ this->stream << "if (" << cond << ") {\n";
+ {
+ int then_scope = this->BeginScope();
+ std::string true_val = PrintExpr(op->args[1]);
+ this->PrintIndent();
+ this->stream << result << " = " << true_val << ";\n} else {\n";
+ this->EndScope(then_scope);
+ }
+ {
+ int else_scope = this->BeginScope();
+ std::string false_val = PrintExpr(op->args[2]);
+ this->PrintIndent();
+ this->stream << result << " = " << false_val << ";\n}\n";
+ this->EndScope(else_scope);
+ }
+ os << result;
} else {
CodeGenC::VisitExpr_(op, os);
}
diff --git a/src/target/spirv/intrin_rule_spirv.cc
b/src/target/spirv/intrin_rule_spirv.cc
index ac304b92b6..ffef425c0e 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -27,6 +27,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
+#include "../intrin_rule.h"
+
namespace tvm {
namespace codegen {
namespace spirv {
@@ -100,6 +102,9 @@
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Tanh>);
+
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
+ codegen::intrin
::DispatchFastErf);
} // namespace intrin
namespace legalize {
diff --git a/tests/python/relax/test_vm_instrument.py
b/tests/python/relax/test_vm_instrument.py
new file mode 100644
index 0000000000..8297da1b74
--- /dev/null
+++ b/tests/python/relax/test_vm_instrument.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+
+from tvm import relax
+from tvm.relax.testing import nn
+from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
+
+
+def get_exec(data_shape):
+ builder = relax.BlockBuilder()
+ weight1_np = np.random.randn(64, 64).astype("float32")
+ weight2_np = np.random.randn(64, 64).astype("float32")
+
+ with builder.function("main"):
+ model = nn.Sequential(
+ nn.Linear(data_shape[1], weight1_np.shape[0], bias=False),
+ nn.ReLU(),
+ nn.Linear(weight2_np.shape[0], weight2_np.shape[1], bias=False),
+ nn.ReLU(),
+ )
+ data = nn.Placeholder(data_shape, name="data")
+ output = model(data)
+ params = [data] + model.parameters()
+ builder.emit_func_output(output, params=params)
+
+ mod = builder.get()
+
+ params = {"linear_weight": weight1_np, "linear_weight1": weight2_np}
+ mod = relax.transform.BindParams("main", params)(mod)
+
+ target = "llvm"
+ return relax.build(mod, target)
+
+
+def test_conv2d_cpu():
+ data_np = np.random.randn(1, 64).astype("float32")
+ ex = get_exec(data_np.shape)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ hit_count = {}
+
+ def instrument(func, name, before_run, ret_val, *args):
+ if (name, before_run) not in hit_count:
+ hit_count[(name, before_run)] = 0
+ hit_count[(name, before_run)] += 1
+ assert callable(func)
+ if before_run:
+ assert ret_val is None
+ if name == "matmul":
+ return relax.VMInstrumentReturnKind.SKIP_RUN
+
+ vm.set_instrument(instrument)
+ vm["main"](tvm.nd.array(data_np))
+ assert hit_count[("matmul", True)] == 2
+ assert ("matmul", False) not in hit_count
+ assert hit_count[("relu", True)] == 2
+ assert hit_count[("relu", False)] == 2
+
+
+def test_lib_comparator():
+ data_np = np.random.randn(1, 64).astype("float32")
+ ex = get_exec(data_np.shape)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ # compare against library module
+ cmp = LibCompareVMInstrument(vm.module.imported_modules[0], tvm.cpu(),
verbose=False)
+ vm.set_instrument(cmp)
+ vm["main"](tvm.nd.array(data_np))
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
index becc92ee71..4a7b0864e1 100644
--- a/web/apps/browser/rpc_server.html
+++ b/web/apps/browser/rpc_server.html
@@ -44,8 +44,8 @@
}
function fetchProgressCallback(report) {
- document.getElementById("progress-tracker-label").innerHTML =
report.text;
- document.getElementById("progress-tracker-progress").value =
(report.fetchedBytes / report.totalBytes) * 100;
+ document.getElementById("rpc-progress-tracker-label").innerHTML =
report.text;
+ document.getElementById("rpc-progress-tracker-progress").value =
(report.fetchedBytes / report.totalBytes) * 100;
}
function connectRPC() {
@@ -130,9 +130,8 @@
<button onclick="connectRPC()">Connect To Proxy</button>
<button onclick="clearLog()">Clear Log</button>
<div id="progress">
- <label id="gpu-tracker-label"> </label><br>
- <label id="progress-tracker-label"> </label> <br>
- <progress id="progress-tracker-progress" max="100" value="100">
</progress>
+ <label id="rpc-progress-tracker-label"> </label> <br>
+ <progress id="rpc-progress-tracker-progress" max="100" value="100">
</progress>
</div>
<div id="includeRPCPlugin"></div>
<div id="log"></div>
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index 960c0ae18b..24acb7ece4 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -137,6 +137,7 @@ export class RPCServer {
this.globalObjects.forEach(obj => {
obj.dispose();
});
+ this.log(this.inst.runtimeStatsText());
this.inst.dispose();
}
if (this.state == RPCServerState.ReceivePacketHeader) {
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 5c2b17fc25..4550ba9837 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -1031,6 +1031,16 @@ export class Instance implements Disposable {
this.ctx.dispose();
this.lib.dispose();
}
+ /**
+ * Obtain the runtime information in readable format.
+ */
+ runtimeStatsText(): string {
+ if (this.lib.webGPUContext !== undefined) {
+ return this.lib.webGPUContext.runtimeStatsText();
+ } else {
+ return "";
+ }
+ }
/**
* Begin a new scope for tracking object disposal.
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index bc466a1543..57a9be8cf9 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -271,6 +271,7 @@ class CanvaRenderManager implements Disposable {
}
}
+
/**
* WebGPU context
* Manages all the webgpu resources here.
@@ -278,11 +279,24 @@ class CanvaRenderManager implements Disposable {
export class WebGPUContext {
device: GPUDevice;
memory: Memory;
-
- //private readBuffer:;
+ // internal data
private bufferTable: Array<GPUBuffer | undefined> = [undefined];
private bufferTableFreeId: Array<number> = [];
private canvasRenderManager?: CanvaRenderManager = undefined;
+ // flags for debugging
+ // stats of the runtime.
+ // peak allocation
+ private peakAllocatedBytes: number = 0;
+ // current allocation
+ private currAllocatedBytes: number = 0;
+ // all allocation(ignoring free)
+ private allAllocatedBytes: number = 0;
+ // shader submit counter
+ private shaderSubmitCounter: number = 0;
+ // limit the number of shaders to be submitted; useful for debugging, defaults to -1
+ protected debugShaderSubmitLimit: number = -1;
+ // log and sync each step
+ protected debugLogFinish: boolean = false;
constructor(memory: Memory, device: GPUDevice) {
this.memory = memory;
@@ -304,6 +318,16 @@ export class WebGPUContext {
this.canvasRenderManager = undefined;
}
+ /**
+ * Obtain the runtime information in readable format.
+ */
+ runtimeStatsText(): string {
+ let info = "peak-memory=" + Math.ceil(this.peakAllocatedBytes / (1 << 20))
+ " MB";
+ info += ", all-memory=" + Math.ceil(this.allAllocatedBytes / (1 << 20)) +
" MB";
+ info += ", shader-submissions=" + this.shaderSubmitCounter;
+ return info;
+ }
+
/**
* Draw image from data in storage buffer.
* @param ptr The GPU ptr
@@ -423,6 +447,12 @@ export class WebGPUContext {
});
const submitShader = (...args: Array<GPUPointer | number>): void => {
+ if (this.debugShaderSubmitLimit != -1 &&
+ this.shaderSubmitCounter >= this.debugShaderSubmitLimit) {
+ this.shaderSubmitCounter += 1;
+ return;
+ }
+
const commandEncoder = this.device.createCommandEncoder();
const compute = commandEncoder.beginComputePass();
compute.setPipeline(pipeline);
@@ -470,6 +500,14 @@ export class WebGPUContext {
compute.end()
const command = commandEncoder.finish();
this.device.queue.submit([command]);
+
+ if (this.debugLogFinish) {
+ const currCounter = this.shaderSubmitCounter;
+ this.device.queue.onSubmittedWorkDone().then(()=> {
+ console.log("["+ currCounter + "][Debug] finish shader" +
finfo.name);
+ });
+ }
+ this.shaderSubmitCounter += 1;
};
return submitShader;
@@ -528,6 +566,11 @@ export class WebGPUContext {
size: nbytes,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC |
GPUBufferUsage.COPY_DST,
});
+ this.currAllocatedBytes += nbytes;
+ this.allAllocatedBytes += nbytes;
+ if (this.currAllocatedBytes > this.peakAllocatedBytes) {
+ this.peakAllocatedBytes = this.currAllocatedBytes;
+ }
const ptr = this.attachToBufferTable(buffer);
return ptr;
}
@@ -538,6 +581,7 @@ export class WebGPUContext {
this.bufferTable[idx] = undefined;
assert(buffer !== undefined);
this.bufferTableFreeId.push(idx);
+ this.currAllocatedBytes -= buffer.size;
buffer.destroy();
}