This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 50d0128a2e [Unity][DEBUG] Add Instrument (#14302)
50d0128a2e is described below

commit 50d0128a2eca29212ccf2ca277f207e670aa69e8
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Mar 15 11:59:57 2023 -0400

    [Unity][DEBUG] Add Instrument (#14302)
    
    This PR adds an instrumentation option to the relax VM.
    The instrument will be called before/after each call
    instruction if specified.
    
    We also include a testing utility that uses the
    instrument. LibCompareVMInstrument leverages the instrument
    to compare implementations against another backend.
    
    Also updated a few places in the web runtime to improve debugging.
---
 include/tvm/runtime/relax_vm/vm.h          |  30 +++++++
 python/tvm/relax/__init__.py               |   2 +-
 python/tvm/relax/testing/lib_comparator.py | 128 +++++++++++++++++++++++++++++
 python/tvm/runtime/relax_vm.py             |  48 +++++++++++
 src/runtime/relax_vm/vm.cc                 |  70 ++++++++++++++--
 src/target/source/codegen_webgpu.cc        |  26 +++++-
 src/target/spirv/intrin_rule_spirv.cc      |   5 ++
 tests/python/relax/test_vm_instrument.py   |  87 ++++++++++++++++++++
 web/apps/browser/rpc_server.html           |   9 +-
 web/src/rpc_server.ts                      |   1 +
 web/src/runtime.ts                         |  10 +++
 web/src/webgpu.ts                          |  48 ++++++++++-
 12 files changed, 445 insertions(+), 19 deletions(-)

diff --git a/include/tvm/runtime/relax_vm/vm.h 
b/include/tvm/runtime/relax_vm/vm.h
index bd59106cc1..95a2080159 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -39,6 +39,16 @@ namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
+/*!
+ * \brief Possible actions an instrument callback can request from the VM.
+ */
+enum class VMInstrumentReturnKind : int {
+  /*! \brief Run the call as normal. */
+  kNoOp = 0,
+  /*! \brief Skip the upcoming call; only valid when returned before the run. */
+  kSkipRun = 1,
+};
+
 /*!
  * \brief An object representing a vm closure.
  */
@@ -119,6 +129,26 @@ class VirtualMachine : public runtime::ModuleNode {
    */
   virtual void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, 
TVMArgs args,
                                    TVMRetValue* rv) = 0;
+  /*!
+   * \brief Set an instrumentation function.
+   *
+   * If instrument is present, the function will be called
+   * before/after each Call instruction.
+   *
+   * int instrument(func, func_symbol, before_run, ret_value, args...)
+   *
+   * - func: Union[VMClosure, PackedFunc], the function object.
+   * - func_symbol: string, the symbol of the function.
+   * - before_run: bool, whether it is before or after call.
+   * - ret_value: Only valid in after run, otherwise it is null.
+   * - args: the arguments being passed to call.
+   *
+   * instrument can return an int which selects a VMInstrumentReturnKind action.
+   * \sa VMInstrumentReturnKind
+   *
+   * \param instrument The instrument function.
+   */
+  virtual void SetInstrument(PackedFunc instrument) = 0;
   /*!
    * \brief Create a specific instance of VM.
    * \return Created VM
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index edbd848bd5..f34a3169bb 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name, wrong-import-position
 """The Relax IR namespace containing the IR, type, operator, builder, vm, 
etc."""
 from tvm.runtime import relax_vm as vm
-from tvm.runtime.relax_vm import VirtualMachine
+from tvm.runtime.relax_vm import VirtualMachine, VMInstrumentReturnKind
 
 # Expr
 from .expr import (
diff --git a/python/tvm/relax/testing/lib_comparator.py 
b/python/tvm/relax/testing/lib_comparator.py
new file mode 100644
index 0000000000..a9cecc69dc
--- /dev/null
+++ b/python/tvm/relax/testing/lib_comparator.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Tools to compare libraries."""
+from typing import List, Tuple, Iterable, Union
+
+import tvm
+import tvm.testing
+
+
+class LibCompareVMInstrument:
+    """Instrument class to compare libs.
+
+    This class builds an instrument function that
+    pair-tests an existing compiled relax vm implementation
+    and an extra module, which can sit in another backend
+    but offers the same subset of compiled TIR functions.
+
+    The instrumentation enables us to automatically
+    check and compare each op being called in the pipeline
+    by looking up the same name in the provided mod and running tests.
+
+    Parameters
+    ----------
+    mod: runtime.Module
+        The module of interest to be validated.
+
+    device: runtime.Device
+        The device to run the target module on.
+
+    verbose: bool
+        Whether to print progress messages.
+
+    rtol: float
+        Relative tolerance used in validation.
+
+    atol: float
+        Absolute tolerance used in validation.
+    """
+
+    def __init__(self, mod, device, verbose=True, rtol=1e-5, atol=1e-5):
+        self.mod = mod
+        self.device = device
+        self.verbose = verbose
+        self.counter = 0
+        self.rtol = rtol
+        self.atol = atol
+
+    def compare(
+        self,
+        name: str,
+        ref_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]],
+        new_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]],
+        ret_indices: Iterable[int],
+    ):
+        """Run ``name`` on ``new_args`` and compare outputs; can be overridden.
+
+        Parameters
+        ----------
+        name: str
+            Name of the function.
+
+        ref_args:
+            The reference arguments, as produced by the VM call.
+
+        new_args:
+            Freshly allocated args passed to the function under test.
+
+        ret_indices:
+            Indices of the arguments that hold outputs to validate.
+        """
+        my_func = self.mod.get_function(name, query_imports=True)
+        if self.verbose:
+            print(f"[{self.counter}] Validating {name} ...")
+        my_func(*new_args)
+        for rindex in ret_indices:
+            tvm.testing.assert_allclose(
+                new_args[rindex].numpy(), ref_args[rindex].numpy(), 
atol=self.atol, rtol=self.rtol
+            )
+        if self.verbose:
+            print(f"[{self.counter}] Validating {name}, passed.")
+        self.counter += 1
+
+    def skip_instrument(self, func, name, before_run, ret_val, *args):
+        return False
+
+    def __call__(self, func, name, before_run, ret_val, *args):
+        if before_run:
+            return
+        if name.startswith("vm.builtin."):
+            return
+        if not args or any(not isinstance(x, tvm.nd.NDArray) for x in args):
+            return
+        try:
+            self.mod.get_function(name, query_imports=True)
+        except AttributeError:
+            if self.verbose:
+                print(f"Cannot find {name}, skip...")
+            return
+
+        if self.skip_instrument(func, name, before_run, ret_val, *args):
+            return
+
+        new_args = []
+        # assume the last arg is the output; not always true, true for most ops.
+        ret_indices = (len(args) - 1,)
+        for i, arg in enumerate(args):
+            arr = tvm.nd.empty(arg.shape, arg.dtype, device=self.device)
+            # stage through cpu since the target may be a different device
+            if i not in ret_indices:
+                arr.copyfrom(arg.copyto(tvm.cpu()))
+            new_args.append(arr)
+
+        self.compare(name, args, new_args, ret_indices)
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index 9defcb7d80..c53882095d 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name, redefined-builtin, no-else-return, 
consider-using-dict-items
 """The Relax virtual machine."""
 from typing import Callable, List, Optional, Union, Dict, Tuple, Any
+from enum import IntEnum
 import numpy as np  # type: ignore
 
 import tvm
@@ -28,6 +29,12 @@ from tvm.runtime.profiling import Report
 from ..rpc.base import RPC_SESS_MASK
 
 
+class VMInstrumentReturnKind(IntEnum):
+    NO_OP = 0
+    # skip the upcoming call; only valid when returned before the run
+    SKIP_RUN = 1
+
+
 class VirtualMachine(object):
     """Relax VM runtime."""
 
@@ -85,6 +92,7 @@ class VirtualMachine(object):
         self._get_output_arity = self.module["get_output_arity"]
         self._get_function_arity = self.module["get_function_arity"]
         self._get_function_param_name = self.module["get_function_param_name"]
+        self._set_instrument = self.module["set_instrument"]
         self._setup_device(device, memory_cfg)
 
     def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, 
str]]) -> None:
@@ -329,6 +337,46 @@ class VirtualMachine(object):
 
         return get_output_rec(func_name)
 
+    def set_instrument(self, instrument: Union[tvm.runtime.PackedFunc, Callable]):
+        """Set an instrumentation function.
+
+        If instrument is present, the function will be called
+        before/after each Call instruction. The function has
+        the following signature:
+
+        .. code:: python
+
+            def instrument(
+                func: Union[VMClosure, PackedFunc],
+                func_symbol: str,
+                before_run: bool,
+                ret_value: Any,
+                *args) -> Optional[int]:
+                pass
+
+        The instrument takes the following parameters:
+        - func: function object to be called.
+        - func_symbol: the symbol name of the function.
+        - before_run: whether it is before or after call.
+        - ret_value: the return value of the call, only valid after run.
+        - args: the arguments being passed to call.
+
+        The instrument function can return an integer,
+        which selects the action taken for the
+        following run. See VMInstrumentReturnKind for
+        more details; a non-integer return means NO_OP.
+
+        Parameters
+        ----------
+        instrument: Union[tvm.runtime.PackedFunc, Callable]
+            An instrumentation function invoked around every VM call instruction.
+
+        See Also
+        --------
+        VMInstrumentReturnKind: the possible return values in VM.
+        """
+        self._set_instrument(instrument)
+
     def time_evaluator(
         self,
         func_name,
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 9a3ce50bcc..6088833dc5 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -193,6 +193,8 @@ class VirtualMachineImpl : public VirtualMachine {
   void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs 
args,
                            TVMRetValue* rv) final;
 
+  void SetInstrument(PackedFunc instrument) final { this->instrument_ = 
instrument; }
+
   //--------------------------------------------------
   // Additional support arguments functions for VM
   //--------------------------------------------------
@@ -382,6 +384,8 @@ class VirtualMachineImpl : public VirtualMachine {
   Index pc_{0};
   /*! \brief The special return register. */
   RegType return_value_;
+  /*! \brief The instrument function; nullptr when no instrument is set. */
+  PackedFunc instrument_ = nullptr;
 };
 
 void VirtualMachineImpl::LoadExecutable(ObjectPtr<Executable> exec) {
@@ -465,6 +469,21 @@ PackedFunc VirtualMachineImpl::GetFunction(const 
std::string& name,
       this->InvokeClosurePacked(clo, TVMArgs(args.values + 1, args.type_codes 
+ 1, args.size() - 1),
                                 rv);
     });
+  } else if (name == "set_instrument") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      PackedFunc func;
+      if (args[0].type_code() != kTVMPackedFuncHandle) {
+        String func_name = args[0];
+        const PackedFunc* factory = Registry::Get(func_name);
+        ICHECK(factory != nullptr) << "Cannot find factory " << func_name;
+        TVMRetValue rv;
+        factory->CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, 
args.num_args - 1), &rv);
+        func = rv;
+      } else {
+        func = args[0];
+      }
+      this->SetInstrument(func);
+    });
   } else if (name == "invoke_stateful") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       std::string func_name = args[0];
@@ -746,11 +765,11 @@ void VirtualMachineImpl::InitFuncPool() {
 
 void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
   DLOG(INFO) << "\n  pc = " << pc_ << ", execute: " << 
GetFuncName(instr.func_idx);
-
+  int args_begin_offset = instrument_ != nullptr ? 4 : 0;
   // Use the call arg stack from the current frame to increase reuse
   // and avoid re-allocation
-  curr_frame->call_arg_values.resize(instr.num_args);
-  curr_frame->call_arg_tcodes.resize(instr.num_args);
+  curr_frame->call_arg_values.resize(args_begin_offset + instr.num_args);
+  curr_frame->call_arg_tcodes.resize(args_begin_offset + instr.num_args);
 
   // NOTE: no changes and resize to those vector ref(otherwise can leads to 
segfault)
   //       in the remainder part of the function.
@@ -760,22 +779,23 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* 
curr_frame, Instruction instr) {
   runtime::TVMArgsSetter setter(values.data(), tcodes.data());
   for (Index i = 0; i < instr.num_args; ++i) {
     Instruction::Arg arg = instr.args[i];
+    int arg_index = args_begin_offset + i;
     switch (arg.kind()) {
       case Instruction::ArgKind::kRegister: {
-        setter(i, ReadRegister(curr_frame, arg.value()));
+        setter(arg_index, ReadRegister(curr_frame, arg.value()));
         break;
       }
       case Instruction::ArgKind::kImmediate: {
-        setter(i, arg.value());
+        setter(arg_index, arg.value());
         break;
       }
       case Instruction::ArgKind::kConstIdx: {
-        setter(i, this->const_pool_[arg.value()]);
+        setter(arg_index, this->const_pool_[arg.value()]);
         break;
       }
       case Instruction::ArgKind::kFuncIdx: {
         ICHECK_LT(static_cast<size_t>(arg.value()), this->func_pool_.size());
-        setter(i, this->func_pool_[arg.value()]);
+        setter(arg_index, this->func_pool_[arg.value()]);
         break;
       }
       default: {
@@ -783,11 +803,43 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* 
curr_frame, Instruction instr) {
       }
     }
   }
-  TVMArgs args(values.data(), tcodes.data(), values.size());
+  TVMArgs args(values.data() + args_begin_offset, tcodes.data() + 
args_begin_offset,
+               instr.num_args);
   TVMRetValue ret;
 
   ICHECK_LT(static_cast<size_t>(instr.func_idx), this->func_pool_.size());
-  this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
+
+  if (instrument_ == nullptr) {
+    this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
+  } else {
+    // insert light-weight instrument callback
+    setter(0, func_pool_[instr.func_idx]);
+    setter(1, GetFuncName(instr.func_idx));
+    setter(2, true);
+    setter(3, nullptr);
+    TVMRetValue rv;
+    // store dtype to str since py callback cannot handle dtype atm.
+    std::vector<std::unique_ptr<std::string>> temp_dtype;
+    for (int i = 0; i < instr.num_args; ++i) {
+      if (tcodes[i + args_begin_offset] == kTVMDataType) {
+        std::string str_dtype = args[i];
+        temp_dtype.emplace_back(std::make_unique<std::string>(str_dtype));
+        setter(i + args_begin_offset, *temp_dtype.back());
+      }
+    }
+    int ret_kind = static_cast<int>(VMInstrumentReturnKind::kNoOp);
+    instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), 
values.size()), &rv);
+    if (rv.type_code() == kDLInt) {
+      ret_kind = rv;
+    }
+    if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) {
+      this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
+      setter(2, false);
+      setter(3, ret);
+      instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), 
values.size()), &rv);
+    }
+  }
+
   // save the return value to the register
   // saving to special register is a NOP
   if (instr.dst < Instruction::kBeginSpecialReg) {
diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index 8ba2b4a65e..ccac00a778 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -335,8 +335,30 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, 
std::ostream& os) {  // NOLIN
     this->PrintExpr(EnforceU32(op->args[1]), os);
     os << ')';
   } else if (op->op.same_as(builtin::if_then_else())) {
-    // WebGPU will insert clamping in buffer access so no need to check OOB.
-    this->PrintExpr(Select(op->args[0], op->args[1], op->args[2]), os);
+    // conditional that skips eval if cond evals to false
+    std::string result = name_supply_->FreshName("condval");
+    std::string cond = PrintExpr(op->args[0]);
+    this->PrintIndent();
+    this->stream << "var " << result << " : ";
+    PrintType(op->dtype, this->stream);
+    this->stream << ";\n";
+    this->PrintIndent();
+    this->stream << "if (" << cond << ") {\n";
+    {
+      int then_scope = this->BeginScope();
+      std::string true_val = PrintExpr(op->args[1]);
+      this->PrintIndent();
+      this->stream << result << " = " << true_val << ";\n} else {\n";
+      this->EndScope(then_scope);
+    }
+    {
+      int else_scope = this->BeginScope();
+      std::string false_val = PrintExpr(op->args[2]);
+      this->PrintIndent();
+      this->stream << result << " = " << false_val << ";\n}\n";
+      this->EndScope(else_scope);
+    }
+    os << result;
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/target/spirv/intrin_rule_spirv.cc 
b/src/target/spirv/intrin_rule_spirv.cc
index ac304b92b6..ffef425c0e 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -27,6 +27,8 @@
 #include <tvm/tir/op.h>
 #include <tvm/tir/op_attr_types.h>
 
+#include "../intrin_rule.h"
+
 namespace tvm {
 namespace codegen {
 namespace spirv {
@@ -100,6 +102,9 @@ 
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
 
 TVM_REGISTER_OP("tir.tanh")
     .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", 
DispatchGLSLPureIntrin<GLSLstd450Tanh>);
+
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
+                                                     codegen::intrin 
::DispatchFastErf);
 }  // namespace intrin
 
 namespace legalize {
diff --git a/tests/python/relax/test_vm_instrument.py 
b/tests/python/relax/test_vm_instrument.py
new file mode 100644
index 0000000000..8297da1b74
--- /dev/null
+++ b/tests/python/relax/test_vm_instrument.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+
+from tvm import relax
+from tvm.relax.testing import nn
+from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
+
+
+def get_exec(data_shape):
+    builder = relax.BlockBuilder()
+    weight1_np = np.random.randn(64, 64).astype("float32")
+    weight2_np = np.random.randn(64, 64).astype("float32")
+
+    with builder.function("main"):
+        model = nn.Sequential(
+            nn.Linear(data_shape[1], weight1_np.shape[0], bias=False),
+            nn.ReLU(),
+            nn.Linear(weight2_np.shape[0], weight2_np.shape[1], bias=False),
+            nn.ReLU(),
+        )
+        data = nn.Placeholder(data_shape, name="data")
+        output = model(data)
+        params = [data] + model.parameters()
+        builder.emit_func_output(output, params=params)
+
+    mod = builder.get()
+
+    params = {"linear_weight": weight1_np, "linear_weight1": weight2_np}
+    mod = relax.transform.BindParams("main", params)(mod)
+
+    target = "llvm"
+    return relax.build(mod, target)
+
+
+def test_conv2d_cpu():
+    data_np = np.random.randn(1, 64).astype("float32")
+    ex = get_exec(data_np.shape)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    hit_count = {}
+
+    def instrument(func, name, before_run, ret_val, *args):
+        if (name, before_run) not in hit_count:
+            hit_count[(name, before_run)] = 0
+        hit_count[(name, before_run)] += 1
+        assert callable(func)
+        if before_run:
+            assert ret_val is None
+        if name == "matmul":
+            return relax.VMInstrumentReturnKind.SKIP_RUN
+
+    vm.set_instrument(instrument)
+    vm["main"](tvm.nd.array(data_np))
+    assert hit_count[("matmul", True)] == 2
+    assert ("matmul", False) not in hit_count
+    assert hit_count[("relu", True)] == 2
+    assert hit_count[("relu", False)] == 2
+
+
+def test_lib_comparator():
+    data_np = np.random.randn(1, 64).astype("float32")
+    ex = get_exec(data_np.shape)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    # compare against library module
+    cmp = LibCompareVMInstrument(vm.module.imported_modules[0], tvm.cpu(), 
verbose=False)
+    vm.set_instrument(cmp)
+    vm["main"](tvm.nd.array(data_np))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
index becc92ee71..4a7b0864e1 100644
--- a/web/apps/browser/rpc_server.html
+++ b/web/apps/browser/rpc_server.html
@@ -44,8 +44,8 @@
     }
 
     function fetchProgressCallback(report) {
-      document.getElementById("progress-tracker-label").innerHTML = 
report.text;
-      document.getElementById("progress-tracker-progress").value = 
(report.fetchedBytes / report.totalBytes) * 100;
+      document.getElementById("rpc-progress-tracker-label").innerHTML = 
report.text;
+      document.getElementById("rpc-progress-tracker-progress").value = 
(report.fetchedBytes / report.totalBytes) * 100;
     }
 
     function connectRPC() {
@@ -130,9 +130,8 @@
     <button onclick="connectRPC()">Connect To Proxy</button>
     <button onclick="clearLog()">Clear Log</button>
     <div id="progress">
-      <label id="gpu-tracker-label"> </label><br>
-      <label id="progress-tracker-label"> </label> <br>
-      <progress id="progress-tracker-progress" max="100" value="100"> 
</progress>
+      <label id="rpc-progress-tracker-label"> </label> <br>
+      <progress id="rpc-progress-tracker-progress" max="100" value="100"> 
</progress>
     </div>
     <div id="includeRPCPlugin"></div>
     <div id="log"></div>
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index 960c0ae18b..24acb7ece4 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -137,6 +137,7 @@ export class RPCServer {
       this.globalObjects.forEach(obj => {
         obj.dispose();
       });
+      this.log(this.inst.runtimeStatsText());
       this.inst.dispose();
     }
     if (this.state == RPCServerState.ReceivePacketHeader) {
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 5c2b17fc25..4550ba9837 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -1031,6 +1031,16 @@ export class Instance implements Disposable {
     this.ctx.dispose();
     this.lib.dispose();
   }
+  /**
+   * Obtain the runtime information in readable format.
+   */
+  runtimeStatsText(): string {
+    if (this.lib.webGPUContext !== undefined) {
+      return this.lib.webGPUContext.runtimeStatsText();
+    } else {
+      return "";
+    }
+  }
 
   /**
    * Begin a new scope for tracking object disposal.
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index bc466a1543..57a9be8cf9 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -271,6 +271,7 @@ class CanvaRenderManager implements Disposable {
   }
 }
 
+
 /**
  * WebGPU context
  * Manages all the webgpu resources here.
@@ -278,11 +279,24 @@ class CanvaRenderManager implements Disposable {
 export class WebGPUContext {
   device: GPUDevice;
   memory: Memory;
-
-  //private readBuffer:;
+  // internal data
   private bufferTable: Array<GPUBuffer | undefined> = [undefined];
   private bufferTableFreeId: Array<number> = [];
   private canvasRenderManager?: CanvaRenderManager = undefined;
+  // flags for debugging
+  // stats of the runtime.
+  // peak allocation
+  private peakAllocatedBytes: number = 0;
+  // current allocation
+  private currAllocatedBytes: number = 0;
+  // all allocation(ignoring free)
+  private allAllocatedBytes: number = 0;
+  // shader submit counter
+  private shaderSubmitCounter: number = 0;
+  // limite number of shaders to be submitted, useful for debugging, default 
to -1
+  protected debugShaderSubmitLimit: number = -1;
+  // log and sync each step
+  protected debugLogFinish: boolean = false;
 
   constructor(memory: Memory, device: GPUDevice) {
     this.memory = memory;
@@ -304,6 +318,16 @@ export class WebGPUContext {
     this.canvasRenderManager = undefined;
   }
 
+  /**
+   * Obtain the runtime information in readable format.
+   */
+  runtimeStatsText(): string {
+    let info = "peak-memory=" + Math.ceil(this.peakAllocatedBytes / (1 << 20)) 
+ " MB";
+    info += ", all-memory=" + Math.ceil(this.allAllocatedBytes / (1 << 20)) + 
" MB";
+    info += ", shader-submissions=" + this.shaderSubmitCounter;
+    return info;
+  }
+
   /**
    * Draw image from data in storage buffer.
    * @param ptr The GPU ptr
@@ -423,6 +447,12 @@ export class WebGPUContext {
     });
 
     const submitShader = (...args: Array<GPUPointer | number>): void => {
+      if (this.debugShaderSubmitLimit != -1 &&
+          this.shaderSubmitCounter >= this.debugShaderSubmitLimit) {
+        this.shaderSubmitCounter += 1;
+        return;
+      }
+
       const commandEncoder = this.device.createCommandEncoder();
       const compute = commandEncoder.beginComputePass();
       compute.setPipeline(pipeline);
@@ -470,6 +500,14 @@ export class WebGPUContext {
       compute.end()
       const command = commandEncoder.finish();
       this.device.queue.submit([command]);
+
+      if (this.debugLogFinish) {
+        const currCounter = this.shaderSubmitCounter;
+        this.device.queue.onSubmittedWorkDone().then(()=> {
+          console.log("["+ currCounter + "][Debug] finish shader" + 
finfo.name);
+        });
+      }
+      this.shaderSubmitCounter += 1;
     };
 
     return submitShader;
@@ -528,6 +566,11 @@ export class WebGPUContext {
       size: nbytes,
       usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | 
GPUBufferUsage.COPY_DST,
     });
+    this.currAllocatedBytes += nbytes;
+    this.allAllocatedBytes += nbytes;
+    if (this.currAllocatedBytes > this.peakAllocatedBytes) {
+      this.peakAllocatedBytes = this.currAllocatedBytes;
+    }
     const ptr = this.attachToBufferTable(buffer);
     return ptr;
   }
@@ -538,6 +581,7 @@ export class WebGPUContext {
     this.bufferTable[idx] = undefined;
     assert(buffer !== undefined);
     this.bufferTableFreeId.push(idx);
+    this.currAllocatedBytes -= buffer.size;
     buffer.destroy();
   }
 

Reply via email to