This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 42c51ca6d3 [Unity][WEB] Improve webgpu codegen options to skip
readonly (#14213)
42c51ca6d3 is described below
commit 42c51ca6d3c24fdad05bd43a3fac2bd85a6fb3e8
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 6 15:25:15 2023 -0500
[Unity][WEB] Improve webgpu codegen options to skip readonly (#14213)
Readonly detection can cause the kernel argument order
to differ from that of other shaders; add an option to
skip it.
Also make export automatically use emcc for the wasm target.
---
include/tvm/runtime/module.h | 4 ++++
python/tvm/runtime/module.py | 8 ++++++-
src/runtime/metal/metal_module.mm | 9 ++++++++
src/runtime/module.cc | 4 ++++
src/target/source/codegen_webgpu.cc | 43 +++++++++++++++++++++++++------------
src/target/source/codegen_webgpu.h | 2 +-
web/apps/browser/rpc_server.html | 3 ++-
web/src/rpc_server.ts | 10 ++++-----
web/src/runtime.ts | 8 +++----
web/src/webgpu.ts | 21 +++++++++++++++---
10 files changed, 82 insertions(+), 30 deletions(-)
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index a54f98a558..ab7b50bbd0 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -189,6 +189,10 @@ class TVM_DLL ModuleNode : public Object {
* \return The corresponding function.
*/
const PackedFunc* GetFuncFromEnv(const std::string& name);
+
+ /*! \brief Clear all imports of the module. */
+ void ClearImports() { imports_.clear(); }
+
/*! \return The module it imports from */
const std::vector<Module>& imports() const { return imports_; }
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 83b436939e..db2b704a60 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -251,6 +251,10 @@ class Module(object):
"""
return _ffi_api.ModuleIsDSOExportable(self)
+ def clear_imports(self):
+ """Remove all imports of the module."""
+ _ffi_api.ModuleClearImports(self)
+
def save(self, file_name, fmt=""):
"""Save the module to file.
@@ -441,7 +445,7 @@ class Module(object):
raise RuntimeError("Cannot call export_library in runtime only
mode")
# Extra dependencies during runtime.
from pathlib import Path
- from tvm.contrib import cc as _cc, tar as _tar, utils as _utils
+ from tvm.contrib import cc as _cc, tar as _tar, utils as _utils, tvmjs
as _tvmjs
if isinstance(file_name, Path):
file_name = str(file_name)
@@ -506,6 +510,8 @@ class Module(object):
if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
+ elif file_name.endswith(".wasm"):
+ fcompile = _tvmjs.create_tvmjs_wasm
else:
fcompile = _cc.create_shared
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index 1e81ac1bbb..a89ef8b5ef 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -277,6 +277,15 @@ Module MetalModuleCreate(std::string data, std::string fmt,
return Module(n);
}
+TVM_REGISTER_GLOBAL("runtime.module.create_metal_module")
+ .set_body_typed([](std::string data, std::string fmap_json) {
+ std::istringstream stream(fmap_json);
+ std::unordered_map<std::string, FunctionInfo> fmap;
+ dmlc::JSONReader reader(&stream);
+ reader.Read(&fmap);
+ return MetalModuleCreate(data, "metal", fmap, "");
+ });
+
// Load module from module.
Module MetalModuleLoadFile(const std::string& file_name, const std::string&
format) {
std::string data;
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index 9ef57e9053..03aba9e8bf 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -186,6 +186,10 @@
TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int
return mod->imports().at(index);
});
+TVM_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module
mod) {
+ mod->ClearImports();
+});
+
TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
return std::string(mod->type_key());
});
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index ff9267ea7a..a4c2ba0b62 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -114,7 +114,7 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
-runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f) {
+runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool
skip_readonly_decl) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from
@@ -130,7 +130,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const
PrimFunc& f) {
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
decl_stream << "//----------------------------------------\n"
- << "// function: " << global_symbol.value() << "\n"
+ << "// Function: " << global_symbol.value() << "\n"
<< "//----------------------------------------\n";
runtime::FunctionInfo func_info;
func_info.name = global_symbol.value();
@@ -167,7 +167,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const
PrimFunc& f) {
if (num_buffer != 0) {
os_param_access << ",";
}
- if (info.write_access_set.count(arg)) {
+ if (skip_readonly_decl || info.write_access_set.count(arg)) {
access_mode = "read_write";
os_param_access << "1";
} else {
@@ -208,7 +208,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const
PrimFunc& f) {
// add to alloc buffer type.
// Function header.
- this->stream << "fn main(\n"
+ this->stream << "fn " << func_info.name << "(\n"
<< " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
<< " @builtin(num_workgroups) gridDim : vec3<u32>,\n"
<< " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
@@ -568,15 +568,21 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
- PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
- stream << "for (var ";
- stream << vid << " : ";
+
+ PrintIndent();
+ stream << "var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
- stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+ stream << " = 0;\n";
+ PrintIndent();
+ stream << "loop {\n";
int for_scope = BeginScope();
+ PrintIndent();
+ stream << "if " << vid << " >= " << extent << " { break; }\n";
PrintStmt(op->body);
+ PrintIndent();
+ stream << vid << "++;\n";
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";
@@ -617,11 +623,17 @@ class WebGPUSourceModuleNode final : public
runtime::ModuleNode {
}
std::string GetSource(const std::string& format) final {
- std::ostringstream os;
- for (auto kv : smap_) {
- os << kv.second;
+ if (format == "func_info") {
+ std::ostringstream stream;
+ dmlc::JSONWriter(&stream).Write(fmap_);
+ return stream.str();
+ } else {
+ std::ostringstream os;
+ for (auto kv : smap_) {
+ os << kv.second;
+ }
+ return os.str();
}
- return os.str();
}
private:
@@ -637,10 +649,13 @@ class WebGPUSourceModuleNode final : public
runtime::ModuleNode {
runtime::Module BuildWebGPU(IRModule mod, Target target) {
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
bool output_ssa = false;
-
+ bool skip_readonly_decl = false;
std::unordered_map<std::string, std::string> smap;
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
+ // narrow all i64 to i32
+ mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod));
+
for (auto kv : mod->functions) {
CodeGenWebGPU cg(target);
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only
take PrimFunc";
@@ -653,7 +668,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol
attribute";
std::string f_name = global_symbol.value();
cg.Init(output_ssa);
- fmap[f_name] = cg.AddFunction(f);
+ fmap[f_name] = cg.AddFunction(f, skip_readonly_decl);
std::string code = cg.Finish();
smap[f_name] = code;
}
diff --git a/src/target/source/codegen_webgpu.h
b/src/target/source/codegen_webgpu.h
index 47f94091a1..ff99f4608a 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,7 +48,7 @@ class CodeGenWebGPU final : public CodeGenC {
explicit CodeGenWebGPU(Target target);
// overrides
std::string Finish() final;
- runtime::FunctionInfo AddFunction(const PrimFunc& f); // NOLINT(*)
+ runtime::FunctionInfo AddFunction(const PrimFunc& f, bool
skip_readonly_decl); // NOLINT(*)
void InitFuncState(const PrimFunc& f) final;
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
index a03e290daa..becc92ee71 100644
--- a/web/apps/browser/rpc_server.html
+++ b/web/apps/browser/rpc_server.html
@@ -130,7 +130,8 @@
<button onclick="connectRPC()">Connect To Proxy</button>
<button onclick="clearLog()">Clear Log</button>
<div id="progress">
- <label id="progress-tracker-label"></div>
+ <label id="gpu-tracker-label"> </label><br>
+ <label id="progress-tracker-label"> </label> <br>
<progress id="progress-tracker-progress" max="100" value="100">
</progress>
</div>
<div id="includeRPCPlugin"></div>
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index 58601230e8..960c0ae18b 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -19,7 +19,7 @@
import { SizeOf, ArgTypeCode } from "./ctypes";
import { assert, StringToUint8Array, Uint8ArrayToString } from "./support";
-import { detectGPUDevice } from "./webgpu";
+import { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
import * as compact from "./compact";
import * as runtime from "./runtime";
import { Disposable } from "./types";
@@ -272,11 +272,11 @@ export class RPCServer {
);
try {
- const gpuDevice: GPUDevice | undefined | null = await
detectGPUDevice();
- if (gpuDevice !== undefined && gpuDevice !== null) {
- const label = gpuDevice.label?.toString() || "WebGPU";
+ const output: GPUDeviceDetectOutput | undefined = await
detectGPUDevice();
+ if (output !== undefined) {
+ const label = "WebGPU: "+ output.adapterInfo.description;
this.log("Initialize GPU device: " + label);
- inst.initWebGPU(gpuDevice);
+ inst.initWebGPU(output.device);
} else {
this.log("Cannot find WebGPU device in the env");
}
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 1f3232c557..c5a9becc7f 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -1373,10 +1373,8 @@ export class Instance implements Disposable {
text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB
fetched "
text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "%
completed, "
text += timeElapsed + " secs elapsed";
- if (timeElapsed != 0){
- text += ", speed=" + (fetchedBytes / timeElapsed / (1024 *
1024)).toFixed(1) + " MB/sec."
- }
- text += " This can take a while during first load.";
+ text += " It can take a while when we first visit this page to
populate the cache."
+ text += " Later refreshes will become faster.";
this.fetchProgressCallback[j]({
fetchedBytes: fetchedBytes,
totalBytes: totalBytes,
@@ -1391,7 +1389,7 @@ export class Instance implements Disposable {
fetchedBytes: 0,
totalBytes: totalBytes,
timeElapsed: 0,
- text: "Start to fetch " + ndarrayCacheUrl
+ text: "Start to fetch params",
});
}
const cache = await caches.open("tvmjs");
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 8b5d2ee543..bc466a1543 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -25,18 +25,33 @@ import { Disposable } from "./types";
/** A pointer to points to the raw address space. */
export type GPUPointer = number;
+export interface GPUDeviceDetectOutput {
+ adapter: GPUAdapter;
+ adapterInfo: GPUAdapterInfo;
+ device: GPUDevice;
+}
+
/**
* DetectGPU device in the environment.
*/
-export async function detectGPUDevice(): Promise<GPUDevice | undefined | null>
{
+export async function detectGPUDevice(): Promise<GPUDeviceDetectOutput |
undefined> {
if (typeof navigator !== "undefined" && navigator.gpu !== undefined) {
const adapter = await navigator.gpu.requestAdapter();
- return await adapter?.requestDevice({
+ if (adapter == null) {
+ throw Error("Cannot find adapter that matches the request");
+ }
+ const adapterInfo = await adapter.requestAdapterInfo();
+ const device = await adapter.requestDevice({
requiredLimits: {
maxStorageBufferBindingSize: 1 << 30,
maxComputeWorkgroupStorageSize: 32 << 10,
}
});
+ return {
+ adapter: adapter,
+ adapterInfo: adapterInfo,
+ device: device
+ };
} else {
return undefined;
}
@@ -403,7 +418,7 @@ export class WebGPUContext {
}
}
}),
- entryPoint: "main"
+ entryPoint: finfo.name
}
});