This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6f7952d7ad [Unity][WEBGPU] Codegen improvements and WebRuntime (#14187)
6f7952d7ad is described below
commit 6f7952d7ad3d7f56eda703f186ae9901698a6a3e
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Mar 3 23:52:38 2023 -0500
[Unity][WEBGPU] Codegen improvements and WebRuntime (#14187)
This PR makes various improvements to the web codegen and the relax web runtime.
Correctly support shift operators.
Update the relax VM to make the most use of internal allocators.
Update the webgpu API to the latest spec.
---
include/tvm/runtime/relax_vm/memory_manager.h | 2 +-
python/tvm/contrib/tvmjs.py | 57 +++-
python/tvm/exec/rpc_proxy.py | 8 +-
src/runtime/relax_vm/memory_manager.cc | 2 +-
src/runtime/relax_vm/vm.cc | 76 +++--
src/target/intrin_rule.cc | 2 +-
src/target/source/codegen_webgpu.cc | 191 +++++++++---
src/target/source/codegen_webgpu.h | 7 +-
tests/lint/check_file_type.py | 1 +
web/apps/browser/rpc_plugin.html | 19 ++
web/apps/browser/rpc_server.html | 25 +-
web/emcc/wasm_runtime.cc | 35 +++
web/src/rpc_server.ts | 18 +-
web/src/runtime.ts | 136 ++++++++-
web/src/webgpu.ts | 402 +++++++++++++++++++++-----
15 files changed, 810 insertions(+), 171 deletions(-)
diff --git a/include/tvm/runtime/relax_vm/memory_manager.h
b/include/tvm/runtime/relax_vm/memory_manager.h
index e5ae8cfcfb..9234e9151c 100644
--- a/include/tvm/runtime/relax_vm/memory_manager.h
+++ b/include/tvm/runtime/relax_vm/memory_manager.h
@@ -61,7 +61,7 @@ class Allocator {
* \param dev The device where the array is allocated.
* \return The empty NDArray.
*/
- runtime::NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, Device
dev);
+ runtime::NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev);
/*! \brief Return the allocator type. */
inline AllocatorType type() const { return type_; }
/*! \brief Allocate a buffer given a size, alignment and type.
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 18cbf332c8..49626e725d 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -45,10 +45,16 @@ def _convert_f32_to_bf16(value):
return ((data + rounding_bias) >> 16).astype("uint16")
+def _convert_bf16_to_f32(value):
+ data = value.view("uint16")
+ return (data.astype("uint32") << 16).view("float32")
+
+
def dump_ndarray_cache(
params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
cachedir: str,
encode_format="f32-to-bf16",
+ meta_data=None,
):
"""Dump parameters to NDArray cache.
@@ -62,7 +68,14 @@ def dump_ndarray_cache(
encode_format: {"f32-to-bf16", "raw"}
Encoding format.
+
+ meta_data: json-compatible-struct
+ Extra meta_data to be stored in the cache json file.
"""
+ if encode_format not in ("raw", "f32-to-bf16"):
+ raise ValueError(f"Invalie encode_format {encode_format}")
+
+ meta_data = {} if meta_data is None else meta_data
records = []
total = len(params)
counter = 0
@@ -101,8 +114,9 @@ def dump_ndarray_cache(
sys.stdout.write(flush + last_cmd)
nd_cache_json = os.path.join(cachedir, "ndarray-cache.json")
+
with open(nd_cache_json, "w") as outfile:
- json.dump(records, outfile, indent=4)
+ json.dump({"metadata": meta_data, "records": records}, outfile,
indent=4)
print("\nAll finished, record saved to %s" % nd_cache_json)
if f32_to_bf16_triggered:
@@ -115,5 +129,44 @@ def dump_ndarray_cache(
b16_nd_cache_json = os.path.join(cachedir, "ndarray-cache-b16.json")
# also dump a file that contains bf16
with open(b16_nd_cache_json, "w") as outfile:
- json.dump(rec_bf16, outfile, indent=4)
+ json.dump({"metadata": meta_data, "records": rec_bf16}, outfile,
indent=4)
print("Also saved a bf16 record to %s" % b16_nd_cache_json)
+
+
+def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
+ """Load the ndarray cache from the directory or json.
+
+
+ Parameters
+ ----------
+ cachepath: str
+ Path to the location or json file.
+
+ device: tvm.runtime.Device
+ The device we would like to load the data from.
+ """
+ if not cachepath.endswith(".json"):
+ cachepath = os.path.join(cachepath, "ndarray-cache.json")
+
+ cachedir = os.path.dirname(cachepath)
+ json_info = json.loads(open(cachepath, "r").read())
+ result_dict = {}
+
+ for rec in json_info["records"]:
+ name = rec["name"]
+ shape = rec["shape"]
+ dtype = rec["dtype"]
+ encode_format = rec["format"]
+ data_path = rec["dataPath"]
+
+ arr = tvm.nd.empty(shape, dtype, device=device)
+ full_data_path = os.path.join(cachedir, data_path)
+
+ if encode_format == "f32-to-bf16":
+ data = np.fromfile(full_data_path, dtype="uint16").reshape(shape)
+ arr.copyfrom(_convert_bf16_to_f32(data))
+ else:
+ data = np.fromfile(full_data_path, dtype=dtype).reshape(shape)
+ arr.copyfrom(data)
+ result_dict[name] = arr
+ return result_dict, json_info["metadata"]
diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index d340750785..8cf1e4010b 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -28,12 +28,14 @@ def find_example_resource():
curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
base_path = os.path.abspath(os.path.join(curr_path, "..", "..", ".."))
index_page = os.path.join(base_path, "web", "apps", "browser",
"rpc_server.html")
+ default_plugin_page = os.path.join(base_path, "web", "apps", "browser",
"rpc_plugin.html")
+
resource_files = [
("/", os.path.join(base_path, "web", "dist", "tvmjs.bundle.js")),
("/", os.path.join(base_path, "web", "dist", "wasm",
"tvmjs_runtime.wasi.js")),
("/", index_page),
]
- allow_format = ("json", "bin", "js", "wasm")
+ allow_format = ("json", "bin", "js", "wasm", "html")
# recursively apend things in www, up to two levels
resource_bases = [
@@ -54,6 +56,10 @@ def find_example_resource():
fname = item[-1]
if not os.path.exists(fname):
raise RuntimeError("Cannot find %s" % fname)
+
+ if not any(item[-1].endswith("rpc_plugin.html") for item in
resource_files):
+ resource_files.append(("/", default_plugin_page))
+
return index_page, resource_files
diff --git a/src/runtime/relax_vm/memory_manager.cc
b/src/runtime/relax_vm/memory_manager.cc
index a017b9c6d9..339045f515 100644
--- a/src/runtime/relax_vm/memory_manager.cc
+++ b/src/runtime/relax_vm/memory_manager.cc
@@ -162,7 +162,7 @@ Allocator* MemoryManager::GetAllocator(Device dev) {
return it->second.get();
}
-runtime::NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType
dtype, DLDevice dev) {
+runtime::NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice
dev) {
VerifyDataType(dtype);
runtime::NDArray::Container* container =
new runtime::NDArray::Container(nullptr, shape, dtype, dev);
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 8679b2a793..9a3ce50bcc 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -21,7 +21,6 @@
* \file src/runtime/relax_vm/vm.cc
*/
-#include <tvm/runtime/container/adt.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/relax_vm/vm.h>
@@ -72,43 +71,46 @@ PackedFunc VMClosure::BindLastArgs(PackedFunc func,
std::vector<TVMRetValue> las
ObjectRef IndexIntoNestedObject(ObjectRef obj, TVMArgs args, int
starting_arg_idx) {
for (int i = starting_arg_idx; i < args.size(); i++) {
// the object must be an ADT to be able to index into it
- if (!obj.as<ADTObj>()) {
+ if (!obj.as<ArrayNode>()) {
LOG(FATAL) << "ValueError: Attempted to index into an object that is not
an ADT.";
}
int index = args[i];
- auto adt = Downcast<ADT>(obj);
+ auto arr = Downcast<Array<ObjectRef>>(obj);
// make sure the index is in bounds
- if (index >= static_cast<int>(adt.size())) {
- LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " <<
adt.size() << ").";
+ if (index >= static_cast<int>(arr.size())) {
+ LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " <<
arr.size() << ").";
}
- obj = adt[index];
+ obj = arr[index];
}
return obj;
}
-NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev) {
+NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator*
alloc) {
if (src->device.device_type == dev.device_type && src->device.device_id ==
dev.device_id) {
return src;
+ } else {
+ auto res = alloc->Empty(src.Shape(), src->dtype, dev);
+ res.CopyFrom(src);
+ return res;
}
- return src.CopyTo(dev);
}
-ObjectRef ConvertObjectToDevice(ObjectRef src, const Device& dev) {
+ObjectRef ConvertObjectToDevice(ObjectRef src, const Device& dev, Allocator*
alloc) {
if (src->IsInstance<NDArray::ContainerType>()) {
- return ConvertNDArrayToDevice(Downcast<NDArray>(src), dev);
- } else if (src->IsInstance<ADTObj>()) {
+ return ConvertNDArrayToDevice(Downcast<NDArray>(src), dev, alloc);
+ } else if (src->IsInstance<ArrayNode>()) {
std::vector<ObjectRef> ret;
- ADT adt = Downcast<ADT>(src);
- for (size_t i = 0; i < adt.size(); i++) {
- ret.push_back(ConvertObjectToDevice(adt[i], dev));
+ auto arr = Downcast<Array<ObjectRef>>(src);
+ for (size_t i = 0; i < arr.size(); i++) {
+ ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc));
}
- return ADT(adt->tag, ret.begin(), ret.end());
+ return Array<ObjectRef>(ret.begin(), ret.end());
} else {
return src;
}
}
-TVMRetValue ConvertArgToDevice(TVMArgValue input, Device dev) {
+TVMRetValue ConvertArgToDevice(TVMArgValue input, Device dev, Allocator*
alloc) {
// NOTE: NDArray::FromExternalDLTensor is not safe
// in terms of memory-behavior.
// To be extra careful, we copy DLTensor.
@@ -117,19 +119,23 @@ TVMRetValue ConvertArgToDevice(TVMArgValue input, Device
dev) {
TVMRetValue ret;
if (input.type_code() == kTVMDLTensorHandle) {
- ret = NDArray::NewFromDLTensor(input, dev);
+ DLTensor* tensor = input;
+ std::vector<int64_t> shape(tensor->shape, tensor->shape + tensor->ndim);
+ auto dst = alloc->Empty(shape, tensor->dtype, dev);
+ dst.CopyFrom(tensor);
+ ret = dst;
} else if (input.IsObjectRef<ObjectRef>()) {
- ret = ConvertObjectToDevice(input.operator ObjectRef(), dev);
+ ret = ConvertObjectToDevice(input.operator ObjectRef(), dev, alloc);
} else {
ret = input;
}
return ret;
}
-TVMRetValue ConvertRegToDevice(TVMRetValue input, Device dev) {
+TVMRetValue ConvertRegToDevice(TVMRetValue input, Device dev, Allocator*
alloc) {
TVMRetValue ret;
if (input.IsObjectRef<ObjectRef>()) {
- ret = ConvertObjectToDevice(input.operator ObjectRef(), dev);
+ ret = ConvertObjectToDevice(input.operator ObjectRef(), dev, alloc);
} else {
ret = input;
}
@@ -196,11 +202,13 @@ class VirtualMachineImpl : public VirtualMachine {
* \param args args[offset:] are arguments to the function. If the arguments
are not of the
* correct device for the function, they will be copied to the device.
* \param offset Starting offset of the arguments in \p args.
- * \note This interface works when using VM over RPC by internally
converting NDArray in
+ * \param with_param_module If set to true, the last argument will be a
module and can be invoked
+ * to get the argument, this is mainly used for debugging purposes
and setting composite
+ * objects. \note This interface works when using VM over RPC by internally
converting NDArray in
* the arguments to DLTensor, which is supported in RPC where remote could
only have a minimal C
* runtime.
*/
- void SetInput(std::string func_name, TVMArgs args, int offset);
+ void SetInput(std::string func_name, TVMArgs args, int offset, bool
with_param_module = false);
/*!
* \brief Look up whether the VM has a function by the given name.
@@ -401,7 +409,7 @@ void VirtualMachineImpl::Init(const std::vector<Device>&
devices,
if (constant.type_code() != kTVMNDArrayHandle) {
this->const_pool_.push_back(constant);
} else {
- this->const_pool_.push_back(ConvertRegToDevice(constant, devices[0]));
+ this->const_pool_.push_back(ConvertRegToDevice(constant, devices[0],
allocators[0]));
}
}
// Setup function sections.
@@ -479,8 +487,8 @@ PackedFunc VirtualMachineImpl::GetFunction(const
std::string& name,
// use remaining args as indices
ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(),
args, 1);
// after chasing through the indices, examine the final object
- if (const auto* adt = obj.as<ADTObj>()) {
- *rv = static_cast<int>(adt->size);
+ if (const auto* arr = obj.as<ArrayNode>()) {
+ *rv = static_cast<int>(arr->size());
} else {
*rv = -1;
}
@@ -491,7 +499,7 @@ PackedFunc VirtualMachineImpl::GetFunction(const
std::string& name,
RegType out = LookupVMOutput(func_name);
// use remaining args as indices
ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(),
args, 1);
- if (obj.as<ADTObj>()) {
+ if (obj.as<ArrayNode>()) {
LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC
compatibility. "
"Please specify another index argument.";
return;
@@ -501,6 +509,9 @@ PackedFunc VirtualMachineImpl::GetFunction(const
std::string& name,
} else if (name == "set_input") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
SetInput(args[0], args, 1); });
+ } else if (name == "set_input_with_param_module") {
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
SetInput(args[0], args, 1, true); });
} else if (name == "get_function_arity") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
@@ -527,7 +538,8 @@ PackedFunc VirtualMachineImpl::GetFunction(const
std::string& name,
}
}
-void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int
offset) {
+void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int
offset,
+ bool with_param_module) {
const auto& m = exec_->func_map;
if (m.find(func_name) != m.end()) {
Index gf_idx = m.at(func_name);
@@ -536,9 +548,15 @@ void VirtualMachineImpl::SetInput(std::string func_name,
TVMArgs args, int offse
ICHECK_EQ(args.size() - offset, params_num)
<< "The number of provided parameters doesn't match the number of
arguments for";
std::vector<RegType> func_args(params_num);
+
for (int i = offset; i < args.size(); ++i) {
int index = i - offset;
- func_args[index] = ConvertArgToDevice(args[i], devices[0]);
+ if (with_param_module && i == args.size() - 1) {
+ // call param func to get the arguments(usually corresponds to param
pack.)
+ func_args[index] = (args[i].operator
Module()).GetFunction("get_params")();
+ } else {
+ func_args[index] = ConvertArgToDevice(args[i], devices[0],
allocators[0]);
+ }
}
inputs_[func_name] = func_args;
} else {
@@ -604,7 +622,7 @@ void VirtualMachineImpl::SaveClosure(const String&
func_name, const String& save
VMClosure clo = this->GetClosure(func_name);
std::vector<RegType> inputs(args.size());
for (int i = 0; i < args.size(); ++i) {
- inputs[i] = ConvertArgToDevice(args[i], this->devices[0]);
+ inputs[i] = ConvertArgToDevice(args[i], this->devices[0],
this->allocators[0]);
}
PackedFunc impl = VMClosure::BindLastArgs(clo->impl, inputs);
if (!include_return) {
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index 398e24d251..ab9a2ff594 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -119,7 +119,7 @@
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
DispatchPureExtern<FloatSuffix>);
PrimExpr DispatchFastErf(const PrimExpr& e) {
- LOG(WARNING) << "fast_erf will be used instead of erf";
+ DLOG(WARNING) << "fast_erf will be used instead of erf";
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index e4ccef88b6..ff9267ea7a 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -23,11 +23,13 @@
#include "codegen_webgpu.h"
#include <tvm/arith/analyzer.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <string>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -39,6 +41,63 @@
namespace tvm {
namespace codegen {
+// WebGPU Info
+struct WebGPUWorkGroupInfo {
+ int workgroup_size[3] = {1, 1, 1};
+ // whether we have ref to block index z is used.
+ bool has_block_index_z{false};
+ // set of handles that have write access
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> write_access_set;
+};
+
+class WebGPUWorkgroupInfoCollector : public StmtExprVisitor {
+ public:
+ static WebGPUWorkGroupInfo Collect(const Stmt& stmt) {
+ WebGPUWorkgroupInfoCollector collector;
+ collector(stmt);
+ return collector.info_;
+ }
+
+ private:
+ void VisitExpr_(const VarNode* op) final {
+ StmtExprVisitor::VisitExpr_(op);
+ Var buffer_var = GetRef<Var>(op);
+ if (buffer_var.dtype().is_handle()) {
+ info_.write_access_set.insert(buffer_var);
+ }
+ }
+
+ void VisitStmt_(const BufferStoreNode* op) final {
+ StmtExprVisitor::VisitStmt_(op);
+ info_.write_access_set.insert(op->buffer->data);
+ }
+
+ void VisitStmt_(const AttrStmtNode* op) final {
+ // record workgroup size
+ if (op->attr_key == tir::attr::thread_extent) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ if (iv->thread_tag.length() != 0) {
+ runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
+ if (ts.rank == 1) {
+ ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized
out by here";
+ ICHECK_LT(ts.dim_index, 3);
+ auto* sizeptr = op->value.as<tir::IntImmNode>();
+ ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group
size "
+ << " get " << op->value;
+ info_.workgroup_size[ts.dim_index] =
static_cast<uint32_t>(sizeptr->value);
+ } else if (ts.rank == 0) {
+ if (ts.dim_index == 2) {
+ info_.has_block_index_z = true;
+ }
+ }
+ }
+ }
+ // normal operation
+ StmtExprVisitor::VisitStmt_(op);
+ }
+ WebGPUWorkGroupInfo info_;
+};
+
std::string CodeGenWebGPU::Finish() {
return decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
}
@@ -51,12 +110,11 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
alloc_storage_scope_[arg.get()] = "global";
}
}
- std::fill(workgroup_size_, workgroup_size_ + 3, 1);
}
CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
-void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
+runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from
@@ -64,6 +122,7 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
// Setup the thread group info.
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+ ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim");
// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
@@ -73,13 +132,22 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
decl_stream << "//----------------------------------------\n"
<< "// function: " << global_symbol.value() << "\n"
<< "//----------------------------------------\n";
+ runtime::FunctionInfo func_info;
+ func_info.name = global_symbol.value();
+
+ WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body);
std::vector<Var> pod_args;
int num_buffer = 0;
+
+ // add param_access modes info to launch params
+ std::ostringstream os_param_access;
+ os_param_access << "paramWriteAccess:[";
// setup buffer argumemts
for (Var arg : f->params) {
DataType t = arg.dtype();
if (t.is_handle()) {
+ func_info.arg_types.push_back(t);
auto* ptr = arg->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a
type_annotation as a "
"PointerType, "
@@ -95,8 +163,20 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
value_storage_type =
boolean_storage_type_.with_lanes(value_storage_type.lanes());
}
std::string vid = AllocVarID(arg.get());
+ std::string access_mode;
+ if (num_buffer != 0) {
+ os_param_access << ",";
+ }
+ if (info.write_access_set.count(arg)) {
+ access_mode = "read_write";
+ os_param_access << "1";
+ } else {
+ access_mode = "read";
+ os_param_access << "0";
+ }
+ // add extra access mode info to launch params
this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
- << "var<storage, read_write> " << vid << " : array<";
+ << "var<storage, " << access_mode << "> " << vid << "
: array<";
this->PrintType(value_storage_type, this->decl_stream);
this->decl_stream << ">;\n";
} else {
@@ -104,15 +184,33 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
}
}
+ // setup thread tags and param access in launch param tags;
+ if (auto opt =
f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
+ auto thread_axis = opt.value();
+ for (size_t i = 0; i < thread_axis.size(); ++i) {
+ func_info.launch_param_tags.push_back(thread_axis[i]->thread_tag);
+ }
+ }
+ os_param_access << "]";
+ func_info.launch_param_tags.push_back(os_param_access.str());
+
if (pod_args.size() != 0) {
// setup POD arguments
// TODO(tvm-team): store as a uniform, readonly buffer.
LOG(FATAL) << "Do not support pod arguments for now";
}
+
+ ICHECK(!info.has_block_index_z)
+ << "blockIdx.z is not supported in WebGPU to accomodate large
blockIdx.x";
+ // anotate workgroup
+ this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
+ << info.workgroup_size[1] << ", " << info.workgroup_size[2] <<
")\n";
+
// add to alloc buffer type.
// Function header.
this->stream << "fn main(\n"
<< " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
+ << " @builtin(num_workgroups) gridDim : vec3<u32>,\n"
<< " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
<< ") {\n";
// the function scope.
@@ -121,39 +219,26 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
- // anotate workgroup
- this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0]
<< ", "
- << workgroup_size_[1] << ", " << workgroup_size_[2] <<
")\n";
-}
-
-void CodeGenWebGPU::VisitStmt_(const AttrStmtNode* op) {
- // record workgroup size
- if (op->attr_key == tir::attr::thread_extent) {
- IterVar iv = Downcast<IterVar>(op->node);
- if (iv->thread_tag.length() != 0) {
- runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
- if (ts.rank == 1) {
- ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out
by here";
- ICHECK_LT(ts.dim_index, 3);
- auto* sizeptr = op->value.as<tir::IntImmNode>();
- ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group
size "
- << " get " << op->value;
- workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
- }
- }
- }
- // normal operation
- CodeGenC::VisitStmt_(op);
+ return func_info;
}
void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
std::ostringstream os;
PrintType(iv->var.dtype(), os);
- os << "(" << iv->thread_tag << ")";
- std::string tidx = os.str();
- this->MarkConst(tidx);
- var_idmap_[iv->var.get()] = tidx;
+ if (iv->thread_tag == "blockIdx.x") {
+ // WebGPU have restriction to limit the maximum size of blockId.x to be
65535
+ // We allow runtime to spread the load out to blockIdx.z so it can be a
large number.
+ os << "(blockIdx.z * gridDim.x + blockIdx.x)";
+ std::string tidx = os.str();
+ std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype());
+ var_idmap_[iv->var.get()] = aggregated_bidx;
+ } else {
+ os << "(" << iv->thread_tag << ")";
+ std::string tidx = os.str();
+ this->MarkConst(tidx);
+ var_idmap_[iv->var.get()] = tidx;
+ }
}
void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
@@ -179,8 +264,10 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream&
os) { // NOLINT(*)
ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support
f16 or f32";
os << "f" << t.bits();
} else if (t.is_uint()) {
+ ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support u64";
os << "u" << t.bits();
} else if (t.is_int()) {
+ ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support i64";
os << "i" << t.bits();
} else {
LOG(FATAL) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU
type";
@@ -221,6 +308,10 @@ void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op,
std::ostream& os) { //
os << ')';
}
+PrimExpr CodeGenWebGPU::EnforceU32(PrimExpr value) {
+ return cast(DataType::UInt(32, value.dtype().lanes()), value);
+}
+
void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { //
NOLINT(*)
if (op->op.same_as(builtin::reinterpret())) {
// generate bitcast<TYPE>(ARG)
@@ -229,7 +320,23 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLIN
os << ">(";
this->PrintExpr(op->args[0], os);
os << ")";
+ } else if (op->op.same_as(builtin::shift_right())) {
+ os << '(';
+ this->PrintExpr(op->args[0], os);
+ os << ">>";
+ // WebGPU requires shift bits to be u32.
+ this->PrintExpr(EnforceU32(op->args[1]), os);
+ os << ')';
+ } else if (op->op.same_as(builtin::shift_left())) {
+ os << '(';
+ this->PrintExpr(op->args[0], os);
+ os << "<<";
+ // WebGPU requires shift bits to be u32.
+ this->PrintExpr(EnforceU32(op->args[1]), os);
+ os << ')';
} else if (op->op.same_as(builtin::if_then_else())) {
+ this->PrintExpr(Select(op->args[0], op->args[1], op->args[2]), os);
+ return;
// conditional that skips eval if cond evals to false
std::string result = name_supply_->FreshName("condval");
std::string cond = PrintExpr(op->args[0]);
@@ -241,14 +348,16 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLIN
this->stream << "if (" << cond << ") {\n";
{
int then_scope = this->BeginScope();
+ std::string true_val = PrintExpr(op->args[1]);
this->PrintIndent();
- this->stream << result << " = " << PrintExpr(op->args[1]) << ";\n} else
{\n";
+ this->stream << result << " = " << true_val << ";\n} else {\n";
this->EndScope(then_scope);
}
{
int else_scope = this->BeginScope();
+ std::string false_val = PrintExpr(op->args[2]);
this->PrintIndent();
- this->stream << result << " = " << PrintExpr(op->args[2]) << ";\n}\n";
+ this->stream << result << " = " << false_val << ";\n}\n";
this->EndScope(else_scope);
}
os << result;
@@ -444,10 +553,13 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
PrintType(op->dtype, this->decl_stream);
this->decl_stream << ", " << constant_size << ">;\n";
} else if (storage_scope.rank == runtime::StorageRank::kLocal) {
- this->PrintIndent();
- this->stream << "var " << vid << " : array<";
- PrintType(op->dtype, this->stream);
- this->stream << ", " << constant_size << ">;\n";
+ this->decl_stream << "var<private> " << vid << " : array<";
+ PrintType(op->dtype, this->decl_stream);
+ this->decl_stream << ", " << constant_size << ">;\n";
+ // this->PrintIndent();
+ // this->stream << "var " << vid << " : array<";
+ // PrintType(op->dtype, this->stream);
+ // this->stream << ", " << constant_size << ">;\n";
} else {
LOG(FATAL) << "WebGPU: Do not support storage scope: " <<
storage_scope.to_string();
}
@@ -527,6 +639,8 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
bool output_ssa = false;
std::unordered_map<std::string, std::string> smap;
+ std::unordered_map<std::string, runtime::FunctionInfo> fmap;
+
for (auto kv : mod->functions) {
CodeGenWebGPU cg(target);
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only
take PrimFunc";
@@ -539,11 +653,12 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol
attribute";
std::string f_name = global_symbol.value();
cg.Init(output_ssa);
- cg.AddFunction(f);
+ fmap[f_name] = cg.AddFunction(f);
std::string code = cg.Finish();
smap[f_name] = code;
}
- auto n = make_object<WebGPUSourceModuleNode>(smap, ExtractFuncInfo(mod));
+
+ auto n = make_object<WebGPUSourceModuleNode>(smap, fmap);
return runtime::Module(n);
}
diff --git a/src/target/source/codegen_webgpu.h
b/src/target/source/codegen_webgpu.h
index 57f226ba8a..47f94091a1 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,7 +48,7 @@ class CodeGenWebGPU final : public CodeGenC {
explicit CodeGenWebGPU(Target target);
// overrides
std::string Finish() final;
- void AddFunction(const PrimFunc& f); // NOLINT(*)
+ runtime::FunctionInfo AddFunction(const PrimFunc& f); // NOLINT(*)
void InitFuncState(const PrimFunc& f) final;
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
@@ -71,15 +71,14 @@ class CodeGenWebGPU final : public CodeGenC {
void VisitStmt_(const BufferStoreNode* op) final;
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const AllocateNode* op) final;
- void VisitStmt_(const AttrStmtNode* op) final;
void VisitStmt_(const AssertStmtNode* op) final;
void VisitStmt_(const AllocateConstNode* op) final;
private:
/*!
- * \brief Records the workgroup size of the kernel.
+ * \brief Enforce value to be U32.
*/
- uint32_t workgroup_size_[3];
+ static PrimExpr EnforceU32(PrimExpr value);
/*!
* \brief Storage type of bool values.
*/
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index 56f812c867..7753961c17 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -130,6 +130,7 @@ ALLOW_SPECIFIC_FILE = {
"apps/wasm-standalone/wasm-graph/.cargo/config",
# html for demo purposes
"web/apps/browser/rpc_server.html",
+ "web/apps/browser/rpc_plugin.html",
# images are normally not allowed
# discuss with committers before add more images
"apps/android_rpc/app/src/main/res/mipmap-hdpi/ic_launcher.png",
diff --git a/web/apps/browser/rpc_plugin.html b/web/apps/browser/rpc_plugin.html
new file mode 100644
index 0000000000..87df60d42b
--- /dev/null
+++ b/web/apps/browser/rpc_plugin.html
@@ -0,0 +1,19 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+<!--- Plugin module -->
+<canvas id="canvas" width="224" height="224"></canvas>
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
index 8fa50272b2..a03e290daa 100644
--- a/web/apps/browser/rpc_server.html
+++ b/web/apps/browser/rpc_server.html
@@ -20,9 +20,15 @@
<head lang="en-US"></head>
<title>TVM RPC Test Page</title>
</head>
+
+ <meta http-equiv="origin-trial"
content="Agx76XA0ITxMPF0Z8rbbcMllwuxsyp9qdtQaXlLqu1JUrdHB6FPonuyIKJ3CsBREUkeioJck4nn3KO0c0kkwqAMAAABJeyJvcmlnaW4iOiJodHRwOi8vbG9jYWxob3N0Ojg4ODgiLCJmZWF0dXJlIjoiV2ViR1BVIiwiZXhwaXJ5IjoxNjkxNzExOTk5fQ==">
+
<script src="tvmjs_runtime.wasi.js"></script>
<script src="tvmjs.bundle.js"></script>
<script>
+ // Global environment
+ var tvmjsGlobalEnv = {};
+
function customLog(message) {
console.log(message);
const d = document.createElement("div");
@@ -38,8 +44,8 @@
}
function fetchProgressCallback(report) {
- document.getElementById("fetch-text").innerHTML = report.text;
- document.getElementById("fetch-progress").value = (report.fetchedBytes /
report.totalBytes) * 100;
+ document.getElementById("progress-tracker-label").innerHTML =
report.text;
+ document.getElementById("progress-tracker-progress").value =
(report.fetchedBytes / report.totalBytes) * 100;
}
function connectRPC() {
@@ -60,7 +66,8 @@
new tvmjs.RPCServer(
proxyUrl, key, getImports, customLog,
- ndarrayCacheUrl, ndarrayCacheDevice, fetchProgressCallback);
+ ndarrayCacheUrl, ndarrayCacheDevice, fetchProgressCallback,
+ tvmjsGlobalEnv.asyncOnRPCServerLoad);
}
async function loadCacheOption() {
@@ -79,6 +86,12 @@
} catch (err) {}
}
</script>
+<script src="https://code.jquery.com/jquery-3.6.3.min.js"
integrity="sha256-pvPw+upLPUjgMXY0G+8O0xUf+/Im1MZjXxxgOcBQBXU="
crossorigin="anonymous"></script>
+<script>
+ $(function(){
+ $("#includeRPCPlugin").load("rpc_plugin.html");
+ });
+</script>
<body onload="loadCacheOption()">
<h1>TVM WebSocket RPC Server</h1>
To use this page
@@ -117,10 +130,10 @@
<button onclick="connectRPC()">Connect To Proxy</button>
<button onclick="clearLog()">Clear Log</button>
<div id="progress">
- <label id="fetch-text"></div>
- <progress id="fetch-progress" max="100" value="100"> </progress>
+ <label id="progress-tracker-label"></div>
+ <progress id="progress-tracker-progress" max="100" value="100">
</progress>
</div>
+ <div id="includeRPCPlugin"></div>
<div id="log"></div>
- <canvas id="canvas" width="224" height="224"></canvas>
</body>
</html>
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index c90b917c5c..8f16365eee 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -185,5 +185,40 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string
bytes, std::string format)
}
TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
+
+class ParamModuleNode : public runtime::ModuleNode {
+ public:
+ const char* type_key() const final { return "param_module"; }
+
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
+ if (name == "get_params") {
+ auto params = params_;
+ return PackedFunc([params](TVMArgs args, TVMRetValue* rv) { *rv =
params; });
+ } else {
+ return PackedFunc();
+ }
+ }
+
+ static Module Create(std::string prefix, int num_params) {
+ Array<NDArray> params;
+ for (int i = 0; i < num_params; ++i) {
+ std::string name = prefix + "_" + std::to_string(i);
+ auto opt = NDArrayCache::Get(name);
+ if (opt) {
+ params.push_back(opt.value());
+ } else {
+ LOG(FATAL) << "Cannot find " << name << " in cache";
+ }
+ }
+ auto n = make_object<ParamModuleNode>();
+ n->params_ = params;
+ return Module(n);
+ }
+
+ private:
+ Array<NDArray> params_;
+};
+
+TVM_REGISTER_GLOBAL("tvmjs.param_module_from_cache").set_body_typed(ParamModuleNode::Create);
} // namespace runtime
} // namespace tvm
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index 4dd7228d3c..58601230e8 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -22,7 +22,6 @@ import { assert, StringToUint8Array, Uint8ArrayToString }
from "./support";
import { detectGPUDevice } from "./webgpu";
import * as compact from "./compact";
import * as runtime from "./runtime";
-import { timeStamp } from "console";
import { Disposable } from "./types";
enum RPCServerState {
@@ -85,6 +84,7 @@ export class RPCServer {
private ndarrayCacheUrl: string;
private ndarrayCacheDevice: string;
private fetchProgressCallback?: runtime.FetchProgressCallback;
+ private asyncOnServerLoad?: (inst: runtime.Instance) => Promise<void>;
private pendingSend: Promise<void> = Promise.resolve();
private name: string;
private inst?: runtime.Instance = undefined;
@@ -104,7 +104,8 @@ export class RPCServer {
logger: (msg: string) => void = console.log,
ndarrayCacheUrl: string = "",
ndarrayCacheDevice: string = "cpu",
- fetchProgressCallback: runtime.FetchProgressCallback | undefined =
undefined
+ fetchProgressCallback: runtime.FetchProgressCallback | undefined =
undefined,
+ asyncOnServerLoad: ((inst: runtime.Instance) => Promise<void>) | undefined
= undefined,
) {
this.url = url;
this.key = key;
@@ -114,7 +115,7 @@ export class RPCServer {
this.ndarrayCacheUrl = ndarrayCacheUrl;
this.ndarrayCacheDevice = ndarrayCacheDevice;
this.fetchProgressCallback = fetchProgressCallback;
-
+ this.asyncOnServerLoad = asyncOnServerLoad;
this.checkLittleEndian();
this.socket = compact.createWebSocket(url);
this.socket.binaryType = "arraybuffer";
@@ -143,7 +144,8 @@ export class RPCServer {
this.log("Automatic reconnecting..");
new RPCServer(
this.url, this.key, this.getImports, this.logger,
- this.ndarrayCacheUrl, this.ndarrayCacheDevice,
this.fetchProgressCallback);
+ this.ndarrayCacheUrl, this.ndarrayCacheDevice,
+ this.fetchProgressCallback, this.asyncOnServerLoad);
} else {
this.log("Closing the server, final state=" + this.state);
}
@@ -268,12 +270,15 @@ export class RPCServer {
this.getImports(),
this.logger
);
+
try {
const gpuDevice: GPUDevice | undefined | null = await
detectGPUDevice();
if (gpuDevice !== undefined && gpuDevice !== null) {
const label = gpuDevice.label?.toString() || "WebGPU";
this.log("Initialize GPU device: " + label);
inst.initWebGPU(gpuDevice);
+ } else {
+ this.log("Cannot find WebGPU device in the env");
}
} catch (err) {
this.log("Cannnot initialize WebGPU, " + err.toString());
@@ -281,7 +286,6 @@ export class RPCServer {
this.inst = inst;
// begin scope to allow handling of objects
- // the object should stay alive during all sessions.
this.inst.beginScope();
if (this.fetchProgressCallback !== undefined) {
this.inst.registerFetchProgressCallback(this.fetchProgressCallback);
@@ -297,8 +301,10 @@ export class RPCServer {
}
assert(this.inst !== undefined);
+ if (this.asyncOnServerLoad !== undefined) {
+ await this.asyncOnServerLoad(this.inst);
+ }
const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer");
-
const messageHandler = fcreate(
(cbytes: Uint8Array): runtime.Scalar => {
assert(this.inst !== undefined);
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 463532762e..1f3232c557 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -148,6 +148,7 @@ class RuntimeContext implements Disposable {
arrayCacheRemove: PackedFunc;
arrayCacheClear: PackedFunc;
arrayDecodeStorage: PackedFunc;
+ paramModuleFromCache: PackedFunc;
private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
@@ -161,6 +162,7 @@ class RuntimeContext implements Disposable {
this.arrayCacheUpdate = getGlobalFunc("tvmjs.ndarray_cache.update");
this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear");
this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
+ this.paramModuleFromCache = getGlobalFunc("tvmjs.param_module_from_cache");
}
@@ -175,6 +177,7 @@ class RuntimeContext implements Disposable {
this.arrayCacheUpdate.dispose();
this.arrayCacheClear.dispose();
this.arrayDecodeStorage.dispose();
+ this.paramModuleFromCache.dispose();
}
beginScope() : void {
@@ -410,7 +413,7 @@ export class NDArray implements Disposable {
/** Device of the array. */
device: DLDevice;
/** Whether it is a temporary view that can become invalid after the call. */
- private isView: boolean;
+ isView: boolean;
private byteOffset: number;
private dltensor: Pointer;
private dataPtr: Pointer;
@@ -479,6 +482,18 @@ export class NDArray implements Disposable {
return this.handle;
}
+ /**
+ * Get dataPtr of NDarray
+ *
+ * @returns The handle.
+ */
+ getDataPtr(): Pointer {
+ if (this.handle == 0) {
+ throw Error("NDArray has already been disposed");
+ }
+ return this.dataPtr;
+ }
+
dispose(): void {
if (this.handle != 0 && !this.isView) {
this.lib.checkCall(
@@ -539,9 +554,12 @@ export class NDArray implements Disposable {
* @returns this
*/
copyFromRawBytes(data: Uint8Array): this {
- if (this.device.deviceType != DeviceStrToEnum.cpu) {
- throw new Error("Can only sync copy CPU array, use
cpu_arr.copyfrom(gpu_arr) then sync instead.");
+ // short cut for gpu copy
+ if (this.device.deviceType == DeviceStrToEnum.webgpu) {
+ this.lib.webGPUContext?.copyRawBytesToBuffer(data, this.getDataPtr(), 0,
data.length);
+ return this;
}
+ // CPU copy
const size = this.shape.reduce((a, b) => {
return a * b;
}, 1);
@@ -910,6 +928,7 @@ export type FetchProgressCallback = (report:
FetchProgressReport) => void;
export class Instance implements Disposable {
memory: Memory;
exports: Record<string, Function>;
+ cacheMetadata: Record<string, any> = {};
private lib: FFILibrary;
private env: Environment;
private objFactory: Map<number, FObjectConstructor>;
@@ -951,7 +970,6 @@ export class Instance implements Disposable {
env = new Environment(importObject);
wasmInstance = new WebAssembly.Instance(wasmModule, env.imports);
}
-
env.start(wasmInstance);
this.env = env;
this.lib = new FFILibrary(wasmInstance, env.imports);
@@ -977,7 +995,7 @@ export class Instance implements Disposable {
* @number The number of times to compute the average.
* @repeat The number of times to repeat the run.
*/
- async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4):
Promise<number[]> {
+ async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=1):
Promise<number[]> {
// Skip first run as it can involve GPU warmup and module loading time.
const perf = compact.getPerformance();
const results = [];
@@ -999,6 +1017,8 @@ export class Instance implements Disposable {
}
dispose(): void {
+ // dispose canvas resource
+ this.lib.webGPUContext?.disposeCanvas();
// order matters
// ctx release goes back into lib.
this.ctx.dispose();
@@ -1016,7 +1036,7 @@ export class Instance implements Disposable {
* End a scope and release all created TVM objects
* under the current scope.
*
- * Exception: one can call retainToParentScope to move
+ * Exception: one can call {@link moveToParentScope} to move
* a value to parent scope.
*/
endScope(): void {
@@ -1030,7 +1050,7 @@ export class Instance implements Disposable {
* @returns The result value.
*
* @note For action to return a valid value,
- * we will need to call {@link retainToParentScope}
+ * we will need to call {@link moveToParentScope}
* for the objects that are created in the scope.
*/
withNewScope<T>(action: ()=>T): T {
@@ -1248,6 +1268,18 @@ export class Instance implements Disposable {
this.fetchProgressCallback.push(cb);
}
+ /**
+ * Get parameters in the form of prefix_i
+ *
+ * @param prefix The parameter prefix.
+ * @param numParams Number of parameters.
+ * @returns
+ */
+ getParamsFromCache(prefix: string, numParams: number) : TVMObject {
+ return (this.ctx.paramModuleFromCache(
+ prefix, new Scalar(numParams, "int32")) as
Module).getFunction("get_params")();
+ }
+
/**
* Get NDArray from cache.
* @param name The name of array.
@@ -1289,17 +1321,20 @@ export class Instance implements Disposable {
*
* @param ndarrayCacheUrl The cache url.
* @param device The device to be fetched to.
+ * @returns The meta data
*/
- async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice) {
+ async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice) :
Promise<any> {
const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href;
var list;
try {
-
list = await (await fetch(jsonUrl)).json();
} catch(err) {
this.env.logger("Cannot fetch " + jsonUrl);
}
- await this.fetchNDArrayCacheInternal(ndarrayCacheUrl, list as
Array<NDArrayCacheEntry>, device);
+ await this.fetchNDArrayCacheInternal(
+ ndarrayCacheUrl,
+ list["records"] as Array<NDArrayCacheEntry>, device);
+ this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as
Record<string, any>) };
}
/**
@@ -1334,14 +1369,14 @@ export class Instance implements Disposable {
const reportCallback = (iter: number)=> {
// report
for (let j = 0; j < this.fetchProgressCallback.length; ++j) {
- let text = "Fetching NDArray Cache[" + iter + "/" + list.length+ "]:";
+ let text = "Fetching param cache[" + iter + "/" + list.length+ "]:";
text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB
fetched "
- text += "from " + Math.ceil(totalBytes / (1024 * 1024)).toString() +
"MB, "
text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "%
completed, "
text += timeElapsed + " secs elapsed";
if (timeElapsed != 0){
- text += ", speed=" + (fetchedBytes / timeElapsed / (1024 *
1024)).toFixed(1) + " MB/sec";
+ text += ", speed=" + (fetchedBytes / timeElapsed / (1024 *
1024)).toFixed(1) + " MB/sec."
}
+ text += " This can take a while during first load.";
this.fetchProgressCallback[j]({
fetchedBytes: fetchedBytes,
totalBytes: totalBytes,
@@ -1542,6 +1577,72 @@ export class Instance implements Disposable {
return ret;
}
+ /**
+ * Create am uniform {@link NDArray} with given shape.
+ *
+ * @param shape The shape of the array.
+ * @param low The low value.
+ * @param high The high value.
+ * @param dev The device of the ndarray.
+ * @returns The created ndarray.
+ */
+ uniform(
+ shape: Array<number>,
+ low: number,
+ high: number,
+ dev: DLDevice
+ ): NDArray {
+ const ret = this.empty(shape, "float32", dev);
+ const size = shape.reduce((a, b) => {
+ return a * b;
+ }, 1);
+ const scale = high - low;
+ const input = new Float32Array(size);
+ for (let i = 0; i < input.length; ++i) {
+ input[i] = low + Math.random() * scale;
+ }
+ return ret.copyFrom(input);
+ }
+
+ /**
+ * Bind canvas to the current WebGPU context
+ * @param canvas The canvas.
+ */
+ bindCanvas(canvas: HTMLCanvasElement) {
+ this.lib.webGPUContext?.bindCanvas(canvas);
+ }
+
+ /**
+ * Show image in canvas.
+ *
+ * @param dataRGBA Image array in height x width uint32 NDArray RGBA format
on GPU.
+ */
+ showImage(dataRGBA: NDArray) {
+ if (dataRGBA.shape.length != 2) {
+ throw Error("Require a height x width uint32 NDArray in RGBA" +
+ "get shape=" + dataRGBA.shape.toString() + " instead."
+ );
+ }
+ if (dataRGBA.device.deviceType != DeviceStrToEnum.webgpu) {
+ throw new Error("Can only run showImage on WebGPU array, " +
+ "get " + DeviceEnumToStr[dataRGBA.device.deviceType] + " instead.");
+ }
+ if (dataRGBA.dtype != "uint32") {
+ throw Error("Require a height x width uint32 NDArray in RGBA, " +
+ "get " + dataRGBA.dtype + " instead.");
+ }
+ this.lib.webGPUContext?.drawImageFromBuffer(
+ dataRGBA.getDataPtr(), dataRGBA.shape[0], dataRGBA.shape[1]
+ );
+ }
+
+ /**
+ * Clear canvas
+ */
+ clearCanvas() {
+ this.lib.webGPUContext?.clearCanvas();
+ }
+
/**
* Create an tuple {@link TVMArray} input array.
*
@@ -1773,8 +1874,13 @@ export class Instance implements Disposable {
const valueOffset = argsValue + i * SizeOf.TVMValue;
const codeOffset = argsCode + i * SizeOf.I32;
if (val instanceof NDArray) {
- stack.storePtr(valueOffset, val.getHandle());
- stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle);
+ if (!val.isView) {
+ stack.storePtr(valueOffset, val.getHandle());
+ stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle);
+ } else {
+ stack.storePtr(valueOffset, val.getHandle());
+ stack.storeI32(codeOffset, ArgTypeCode.TVMDLTensorHandle);
+ }
} else if (val instanceof Scalar) {
if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) {
stack.storeI64(valueOffset, val.value);
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index faf6fac990..8b5d2ee543 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -20,6 +20,7 @@ import "@webgpu/types";
import { assert } from "./support";
import { Pointer } from "./ctypes";
import { Memory } from "./memory";
+import { Disposable } from "./types";
/** A pointer to points to the raw address space. */
export type GPUPointer = number;
@@ -30,7 +31,12 @@ export type GPUPointer = number;
export async function detectGPUDevice(): Promise<GPUDevice | undefined | null>
{
if (typeof navigator !== "undefined" && navigator.gpu !== undefined) {
const adapter = await navigator.gpu.requestAdapter();
- return await adapter?.requestDevice();
+ return await adapter?.requestDevice({
+ requiredLimits: {
+ maxStorageBufferBindingSize: 1 << 30,
+ maxComputeWorkgroupStorageSize: 32 << 10,
+ }
+ });
} else {
return undefined;
}
@@ -42,6 +48,214 @@ interface FunctionInfo {
launch_param_tags: Array<string>;
}
+const canvasRenderWGSL =`
+@group(0) @binding(0) var my_sampler : sampler;
+@group(0) @binding(1) var my_texture : texture_2d<f32>;
+
+struct VertexOutput {
+ @builtin(position) position : vec4<f32>,
+ @location(0) uv : vec2<f32>,
+}
+
+@vertex
+fn vertex_main(@builtin(vertex_index) vidx : u32) -> VertexOutput {
+ const pos = array(
+ vec2( 1.0, 1.0),
+ vec2( 1.0, -1.0),
+ vec2(-1.0, -1.0),
+ vec2( 1.0, 1.0),
+ vec2(-1.0, -1.0),
+ vec2(-1.0, 1.0),
+ );
+
+ const uv = array(
+ vec2(1.0, 0.0),
+ vec2(1.0, 1.0),
+ vec2(0.0, 1.0),
+ vec2(1.0, 0.0),
+ vec2(0.0, 1.0),
+ vec2(0.0, 0.0),
+ );
+
+ var output : VertexOutput;
+ output.position = vec4(pos[vidx], 0.0, 1.0);
+ output.uv = uv[vidx];
+ return output;
+}
+
+@fragment
+fn fragment_main(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
+ return textureSample(my_texture, my_sampler, uv);
+}
+
+@fragment
+fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
+ return vec4(1.0, 1.0, 1.0, 1.0);
+}
+`
+class CanvaRenderManager implements Disposable {
+ private device: GPUDevice;
+ private canvasContext: GPUCanvasContext;
+ private stagingTexture: GPUTexture;
+ private renderSampler: GPUSampler;
+ private renderPipeline: GPURenderPipeline;
+ private clearPipeline: GPURenderPipeline;
+ private canvasTextureFormat: GPUTextureFormat;
+
+ constructor(device: GPUDevice, canvas: HTMLCanvasElement) {
+ this.device = device;
+ const ctx = canvas.getContext("webgpu");
+ if (ctx == null) {
+ throw Error("Cannot bind WebGPU context");
+ }
+ this.canvasContext = ctx;
+ this.canvasTextureFormat = navigator.gpu.getPreferredCanvasFormat();
+ this.canvasContext.configure({
+ device: this.device,
+ format: this.canvasTextureFormat,
+ alphaMode: "opaque",
+ });
+
+ this.renderPipeline = device.createRenderPipeline({
+ layout: "auto",
+ vertex: {
+ module: device.createShaderModule({
+ code: canvasRenderWGSL,
+ }),
+ entryPoint: "vertex_main",
+ },
+ fragment: {
+ module: device.createShaderModule({
+ code: canvasRenderWGSL,
+ }),
+ entryPoint: "fragment_main",
+ targets: [{
+ format: this.canvasTextureFormat,
+ }],
+ },
+ primitive: {
+ topology: "triangle-list",
+ },
+ });
+
+ this.clearPipeline = device.createRenderPipeline({
+ layout: "auto",
+ vertex: {
+ module: device.createShaderModule({
+ code: canvasRenderWGSL,
+ }),
+ entryPoint: "vertex_main",
+ },
+ fragment: {
+ module: device.createShaderModule({
+ code: canvasRenderWGSL,
+ }),
+ entryPoint: "fragment_clear",
+ targets: [{
+ format: this.canvasTextureFormat,
+ }],
+ },
+ primitive: {
+ topology: "triangle-list",
+ },
+ });
+
+ this.renderSampler = device.createSampler({
+ magFilter: "linear",
+ minFilter: "linear",
+ });
+ // staging texture always be in RGBA
+ this.stagingTexture = device.createTexture({
+ size: [canvas.height, canvas.width, 1],
+ format: "rgba8unorm",
+ usage:
+ GPUTextureUsage.TEXTURE_BINDING |
+ GPUTextureUsage.COPY_DST |
+ GPUTextureUsage.RENDER_ATTACHMENT,
+ });
+ }
+
+ clear() {
+ const commandEncoder = this.device.createCommandEncoder();
+ const passEncoder = commandEncoder.beginRenderPass({
+ colorAttachments: [
+ {
+ view: this.canvasContext.getCurrentTexture().createView(),
+ clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 },
+ loadOp: "clear",
+ storeOp: "store",
+ },
+ ],
+ });
+ passEncoder.setPipeline(this.clearPipeline);
+ const renderBindingGroup = this.device.createBindGroup({
+ layout: this.renderPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: this.renderSampler },
+ { binding: 1, resource: this.stagingTexture.createView() },
+ ],
+ });
+ passEncoder.setBindGroup(0, renderBindingGroup);
+ passEncoder.draw(6, 1, 0, 0);
+ passEncoder.end();
+ this.device.queue.submit([commandEncoder.finish()]);
+ }
+
+ draw(buffer: GPUBuffer, height: number, width: number) {
+ // resize the staging texture
+ if (height != this.stagingTexture.height || width !=
this.stagingTexture.width) {
+ this.stagingTexture.destroy();
+ this.stagingTexture = this.device.createTexture({
+ size: [height, width, 1],
+ format: "rgba8unorm",
+ usage:
+ GPUTextureUsage.TEXTURE_BINDING |
+ GPUTextureUsage.COPY_DST |
+ GPUTextureUsage.RENDER_ATTACHMENT,
+ });
+ }
+
+ const commandEncoder = this.device.createCommandEncoder();
+ commandEncoder.copyBufferToTexture({
+ buffer: buffer,
+ offset: 0,
+ bytesPerRow: this.stagingTexture.width * 4
+ }, {
+ texture: this.stagingTexture
+ },{
+ width: this.stagingTexture.width,
+ height: this.stagingTexture.height
+ });
+
+ const passEncoder = commandEncoder.beginRenderPass({
+ colorAttachments: [
+ {
+ view: this.canvasContext.getCurrentTexture().createView(),
+ clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 },
+ loadOp: "clear",
+ storeOp: "store",
+ },
+ ],
+ });
+ passEncoder.setPipeline(this.renderPipeline);
+ const renderBindingGroup = this.device.createBindGroup({
+ layout: this.renderPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: this.renderSampler },
+ { binding: 1, resource: this.stagingTexture.createView() },
+ ],
+ });
+ passEncoder.setBindGroup(0, renderBindingGroup);
+ passEncoder.draw(6, 1, 0, 0);
+ passEncoder.end();
+ this.device.queue.submit([commandEncoder.finish()]);
+ }
+
+ dispose() : void {
+ this.stagingTexture.destroy();
+ }
+}
+
/**
* WebGPU context
* Manages all the webgpu resources here.
@@ -53,8 +267,7 @@ export class WebGPUContext {
//private readBuffer:;
private bufferTable: Array<GPUBuffer | undefined> = [undefined];
private bufferTableFreeId: Array<number> = [];
- private pendingRead: Promise<void> = Promise.resolve();
- private numPendingReads = 0;
+ private canvasRenderManager?: CanvaRenderManager = undefined;
constructor(memory: Memory, device: GPUDevice) {
this.memory = memory;
@@ -65,14 +278,66 @@ export class WebGPUContext {
* Wait for all pending GPU tasks to complete
*/
async sync(): Promise<void> {
- if (this.numPendingReads != 0) {
- await Promise.all([
- this.device.queue.onSubmittedWorkDone(),
- this.pendingRead
- ])
- } else {
- await this.device.queue.onSubmittedWorkDone()
+ await this.device.queue.onSubmittedWorkDone();
+ }
+
+ /**
+ * Dispose the binded canvas.
+ */
+ disposeCanvas() {
+ this.canvasRenderManager?.dispose();
+ this.canvasRenderManager = undefined;
+ }
+
+ /**
+ * Draw image from data in storage buffer.
+ * @param ptr The GPU ptr
+ * @param height The height of the image.
+ * @param width The width of the image.
+ */
+ drawImageFromBuffer(ptr: GPUPointer, height: number, width: number) {
+ if (this.canvasRenderManager == undefined) {
+ throw Error("Do not have a canvas context, call bindCanvas first");
}
+ this.canvasRenderManager.draw(this.gpuBufferFromPtr(ptr), height, width);
+ }
+
+ /**
+ * Copy raw bytes into buffer ptr.
+ *
+ * @param rawBytes The raw bytes
+ * @param toPtr The target gpu buffer ptr
+ * @param toOffset The beginning offset
+ * @param nbytes Number of bytes
+ */
+ copyRawBytesToBuffer(
+ rawBytes: Uint8Array,
+ toPtr: GPUPointer,
+ toOffset: number,
+ nbytes: number
+ ): void {
+ // Perhaps it would be more useful to use a staging buffer?
+ this.device.queue.writeBuffer(
+ this.gpuBufferFromPtr(toPtr),
+ toOffset,
+ rawBytes,
+ 0,
+ nbytes
+ );
+ }
+ /**
+ * Clear canvas
+ */
+ clearCanvas() {
+ this.canvasRenderManager?.clear();
+ }
+
+ /**
+ * Bind a canvas element to the runtime.
+ * @param canvas The HTML canvas/
+ */
+ bindCanvas(canvas: HTMLCanvasElement) {
+ this.canvasRenderManager = new CanvaRenderManager(this.device, canvas);
}
/**
@@ -83,6 +348,28 @@ export class WebGPUContext {
*/
createShader(info: string, code: string): Function {
const finfo = JSON.parse(info);
+ const dispatchToDim: Array<number> = [];
+ let paramWriteAccess: Array<number> = [];
+
+ for (let i = 0; i < finfo.launch_param_tags.length; ++i) {
+ const tag: string = finfo.launch_param_tags[i];
+ if (tag.startsWith("blockIdx.")) {
+ const target: number = tag.charCodeAt(tag.length - 1) -
("x".charCodeAt(0));
+ assert(target >= 0 && target < 3);
+ dispatchToDim.push(target);
+ } else if (tag.startsWith("threadIdx.")) {
+ const target: number = tag.charCodeAt(tag.length - 1) -
("x".charCodeAt(0));
+ assert(target >= 0 && target < 3);
+ dispatchToDim.push(target + 3);
+ } else if (tag.startsWith("paramWriteAccess:")) {
+ paramWriteAccess = JSON.parse(tag.substring(17));
+ } else {
+ throw new Error("Cannot handle thread_axis " + tag);
+ }
+ }
+
+ assert(paramWriteAccess.length == finfo.arg_types.length);
+
const layoutEntries: Array<GPUBindGroupLayoutEntry> = [];
for (let i = 0; i < finfo.arg_types.length; ++i) {
const dtype = finfo.arg_types[i];
@@ -91,7 +378,7 @@ export class WebGPUContext {
binding: i,
visibility: GPUShaderStage.COMPUTE,
buffer : {
- type: "storage"
+ type: paramWriteAccess[i] ? "storage" : "read-only-storage"
}
});
} else {
@@ -101,36 +388,25 @@ export class WebGPUContext {
const bindGroupLayout = this.device.createBindGroupLayout({
entries: layoutEntries
});
+ const pipelineLayout = this.device.createPipelineLayout({
+ bindGroupLayouts: [ bindGroupLayout ]
+ });
const pipeline = this.device.createComputePipeline({
- layout: this.device.createPipelineLayout({
- bindGroupLayouts: [ bindGroupLayout ]
- }),
+ layout: pipelineLayout,
compute: {
module: this.device.createShaderModule({
- code: code
+ code: code,
+ hints: {
+ main: {
+ layout: pipelineLayout
+ }
+ }
}),
entryPoint: "main"
}
});
- const dispatchToDim: Array<number> = [];
-
- for (let i = 0; i < finfo.launch_param_tags.length; ++i) {
- const tag: string = finfo.launch_param_tags[i];
- if (tag.startsWith("blockIdx.")) {
- const target: number = tag.charCodeAt(tag.length - 1) -
("x".charCodeAt(0));
- assert(target >= 0 && target < 3);
- dispatchToDim.push(target);
- } else if (tag.startsWith("threadIdx.")) {
- const target: number = tag.charCodeAt(tag.length - 1) -
("x".charCodeAt(0));
- assert(target >= 0 && target < 3);
- dispatchToDim.push(target + 3);
- } else {
- throw new Error("Cannot handle thread_axis " + tag);
- }
- }
-
const submitShader = (...args: Array<GPUPointer | number>): void => {
const commandEncoder = this.device.createCommandEncoder();
const compute = commandEncoder.beginComputePass();
@@ -155,6 +431,26 @@ export class WebGPUContext {
for (let i = 0; i < dispatchToDim.length; ++i) {
wl[dispatchToDim[i]] = args[layoutEntries.length + i];
}
+
+ // get around 65535 restriction of blockIdx.x
+ if (wl[2] != 1) {
+ throw Error("WebGPU: blockIdx.z is reserved for internal use");
+ }
+ // spread thinsg out into blockIdx.z
+ if (wl[0] >= (1 << 16)) {
+ let wl_x = wl[0];
+ let wl_z = wl[2];
+
+ while (wl_x >= (1 << 16)) {
+ if (wl_x % 2 != 0) {
+ throw Error("WebGPU: cannot factorize big gridDim.x=" +
wl[0].toString());
+ }
+ wl_x /= 2;
+ wl_z *= 2;
+ }
+ wl[0] = wl_x;
+ wl[2] = wl_z;
+ }
compute.dispatchWorkgroups(wl[0], wl[1], wl[2])
compute.end()
const command = commandEncoder.finish();
@@ -209,7 +505,6 @@ export class WebGPUContext {
} else {
throw new Error("Unknown DeviceAPI function " + name);
}
-
}
// DeviceAPI
@@ -218,7 +513,8 @@ export class WebGPUContext {
size: nbytes,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC |
GPUBufferUsage.COPY_DST,
});
- return this.attachToBufferTable(buffer);
+ const ptr = this.attachToBufferTable(buffer);
+ return ptr;
}
private deviceFreeDataSpace(ptr: GPUPointer): void {
@@ -237,29 +533,14 @@ export class WebGPUContext {
nbytes: number
): void {
// Perhaps it would be more useful to use a staging buffer?
- const gpuTemp = this.device.createBuffer({
- mappedAtCreation: true,
- size: nbytes,
- usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC
- });
-
- const cpuTemp = gpuTemp.getMappedRange();
-
- const viewU8 = new Uint8Array(cpuTemp);
- viewU8.set(this.memory.loadRawBytes(from, nbytes));
- gpuTemp.unmap();
-
- const copyEncoder = this.device.createCommandEncoder();
- copyEncoder.copyBufferToBuffer(
- gpuTemp,
- 0,
+ const rawBytes = this.memory.loadRawBytes(from, nbytes);
+ this.device.queue.writeBuffer(
this.gpuBufferFromPtr(to),
toOffset,
+ rawBytes,
+ 0,
nbytes
);
- const copyCommands = copyEncoder.finish();
- this.device.queue.submit([copyCommands]);
- gpuTemp.destroy();
}
private deviceCopyFromGPU(
@@ -285,24 +566,11 @@ export class WebGPUContext {
const copyCommands = copyEncoder.finish();
this.device.queue.submit([copyCommands]);
- this.numPendingReads += 1;
-
- const readEvent = gpuTemp.mapAsync(GPUMapMode.READ).then(() => {
+ gpuTemp.mapAsync(GPUMapMode.READ).then(() => {
const data = gpuTemp.getMappedRange();
this.memory.storeRawBytes(to, new Uint8Array(data));
- this.numPendingReads -= 1;
gpuTemp.destroy();
});
-
- if (this.numPendingReads == 1) {
- this.pendingRead = readEvent;
- } else {
- this.pendingRead = Promise.all([
- this.pendingRead,
- readEvent,
- // eslint-disable-next-line @typescript-eslint/no-empty-function
- ]).then(() => {});
- }
}
private deviceCopyWithinGPU(