This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 42c51ca6d3 [Unity][WEB] Improve webgpu codegen options to skip readonly (#14213)
42c51ca6d3 is described below

commit 42c51ca6d3c24fdad05bd43a3fac2bd85a6fb3e8
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 6 15:25:15 2023 -0500

    [Unity][WEB] Improve webgpu codegen options to skip readonly (#14213)
    
    Readonly detection can cause the kernel argument order
    to differ from that of other shaders, so add an option
    to skip it.
    
    Also make library export automatically use emcc when the output file name ends with ".wasm".
---
 include/tvm/runtime/module.h        |  4 ++++
 python/tvm/runtime/module.py        |  8 ++++++-
 src/runtime/metal/metal_module.mm   |  9 ++++++++
 src/runtime/module.cc               |  4 ++++
 src/target/source/codegen_webgpu.cc | 43 +++++++++++++++++++++++++------------
 src/target/source/codegen_webgpu.h  |  2 +-
 web/apps/browser/rpc_server.html    |  3 ++-
 web/src/rpc_server.ts               | 10 ++++-----
 web/src/runtime.ts                  |  8 +++----
 web/src/webgpu.ts                   | 21 +++++++++++++++---
 10 files changed, 82 insertions(+), 30 deletions(-)

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index a54f98a558..ab7b50bbd0 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -189,6 +189,10 @@ class TVM_DLL ModuleNode : public Object {
    * \return The corresponding function.
    */
   const PackedFunc* GetFuncFromEnv(const std::string& name);
+
+  /*! \brief Clear all imports of the module. */
+  void ClearImports() { imports_.clear(); }
+
   /*! \return The module it imports from */
   const std::vector<Module>& imports() const { return imports_; }
 
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 83b436939e..db2b704a60 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -251,6 +251,10 @@ class Module(object):
         """
         return _ffi_api.ModuleIsDSOExportable(self)
 
+    def clear_imports(self):
+        """Remove all imports of the module."""
+        _ffi_api.ModuleClearImports(self)
+
     def save(self, file_name, fmt=""):
         """Save the module to file.
 
@@ -441,7 +445,7 @@ class Module(object):
             raise RuntimeError("Cannot call export_library in runtime only 
mode")
         # Extra dependencies during runtime.
         from pathlib import Path
-        from tvm.contrib import cc as _cc, tar as _tar, utils as _utils
+        from tvm.contrib import cc as _cc, tar as _tar, utils as _utils, tvmjs 
as _tvmjs
 
         if isinstance(file_name, Path):
             file_name = str(file_name)
@@ -506,6 +510,8 @@ class Module(object):
         if not fcompile:
             if file_name.endswith(".tar"):
                 fcompile = _tar.tar
+            elif file_name.endswith(".wasm"):
+                fcompile = _tvmjs.create_tvmjs_wasm
             else:
                 fcompile = _cc.create_shared
 
diff --git a/src/runtime/metal/metal_module.mm 
b/src/runtime/metal/metal_module.mm
index 1e81ac1bbb..a89ef8b5ef 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -277,6 +277,15 @@ Module MetalModuleCreate(std::string data, std::string fmt,
   return Module(n);
 }
 
+TVM_REGISTER_GLOBAL("runtime.module.create_metal_module")
+    .set_body_typed([](std::string data, std::string fmap_json) {
+      std::istringstream stream(fmap_json);
+      std::unordered_map<std::string, FunctionInfo> fmap;
+      dmlc::JSONReader reader(&stream);
+      reader.Read(&fmap);
+      return MetalModuleCreate(data, "metal", fmap, "");
+    });
+
 // Load module from module.
 Module MetalModuleLoadFile(const std::string& file_name, const std::string& 
format) {
   std::string data;
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index 9ef57e9053..03aba9e8bf 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -186,6 +186,10 @@ 
TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int
   return mod->imports().at(index);
 });
 
+TVM_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module 
mod) {
+  mod->ClearImports();
+});
+
 TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
   return std::string(mod->type_key());
 });
diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index ff9267ea7a..a4c2ba0b62 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -114,7 +114,7 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
 
 CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
 
-runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f) {
+runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool 
skip_readonly_decl) {
   // clear previous generated state.
   this->InitFuncState(f);
   // skip the first underscore, so SSA variable starts from
@@ -130,7 +130,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const 
PrimFunc& f) {
       << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
 
   decl_stream << "//----------------------------------------\n"
-              << "// function: " << global_symbol.value() << "\n"
+              << "// Function: " << global_symbol.value() << "\n"
               << "//----------------------------------------\n";
   runtime::FunctionInfo func_info;
   func_info.name = global_symbol.value();
@@ -167,7 +167,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const 
PrimFunc& f) {
       if (num_buffer != 0) {
         os_param_access << ",";
       }
-      if (info.write_access_set.count(arg)) {
+      if (skip_readonly_decl || info.write_access_set.count(arg)) {
         access_mode = "read_write";
         os_param_access << "1";
       } else {
@@ -208,7 +208,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const 
PrimFunc& f) {
 
   // add to alloc buffer type.
   // Function header.
-  this->stream << "fn main(\n"
+  this->stream << "fn " << func_info.name << "(\n"
                << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
                << "  @builtin(num_workgroups) gridDim : vec3<u32>,\n"
                << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
@@ -568,15 +568,21 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
 
 void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
   std::string extent = PrintExpr(op->extent);
-  PrintIndent();
   std::string vid = AllocVarID(op->loop_var.get());
   ICHECK(is_zero(op->min));
-  stream << "for (var ";
-  stream << vid << " : ";
+
+  PrintIndent();
+  stream << "var " << vid << " : ";
   PrintType(op->loop_var.dtype(), stream);
-  stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+  stream << " = 0;\n";
+  PrintIndent();
+  stream << "loop {\n";
   int for_scope = BeginScope();
+  PrintIndent();
+  stream << "if " << vid << " >= " << extent << " { break; }\n";
   PrintStmt(op->body);
+  PrintIndent();
+  stream << vid << "++;\n";
   this->EndScope(for_scope);
   PrintIndent();
   stream << "}\n";
@@ -617,11 +623,17 @@ class WebGPUSourceModuleNode final : public 
runtime::ModuleNode {
   }
 
   std::string GetSource(const std::string& format) final {
-    std::ostringstream os;
-    for (auto kv : smap_) {
-      os << kv.second;
+    if (format == "func_info") {
+      std::ostringstream stream;
+      dmlc::JSONWriter(&stream).Write(fmap_);
+      return stream.str();
+    } else {
+      std::ostringstream os;
+      for (auto kv : smap_) {
+        os << kv.second;
+      }
+      return os.str();
     }
-    return os.str();
   }
 
  private:
@@ -637,10 +649,13 @@ class WebGPUSourceModuleNode final : public 
runtime::ModuleNode {
 runtime::Module BuildWebGPU(IRModule mod, Target target) {
   mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
   bool output_ssa = false;
-
+  bool skip_readonly_decl = false;
   std::unordered_map<std::string, std::string> smap;
   std::unordered_map<std::string, runtime::FunctionInfo> fmap;
 
+  // narrow all i64 to i32
+  mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod));
+
   for (auto kv : mod->functions) {
     CodeGenWebGPU cg(target);
     ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only 
take PrimFunc";
@@ -653,7 +668,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
         << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol 
attribute";
     std::string f_name = global_symbol.value();
     cg.Init(output_ssa);
-    fmap[f_name] = cg.AddFunction(f);
+    fmap[f_name] = cg.AddFunction(f, skip_readonly_decl);
     std::string code = cg.Finish();
     smap[f_name] = code;
   }
diff --git a/src/target/source/codegen_webgpu.h 
b/src/target/source/codegen_webgpu.h
index 47f94091a1..ff99f4608a 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,7 +48,7 @@ class CodeGenWebGPU final : public CodeGenC {
   explicit CodeGenWebGPU(Target target);
   // overrides
   std::string Finish() final;
-  runtime::FunctionInfo AddFunction(const PrimFunc& f);  // NOLINT(*)
+  runtime::FunctionInfo AddFunction(const PrimFunc& f, bool 
skip_readonly_decl);  // NOLINT(*)
   void InitFuncState(const PrimFunc& f) final;
   void PrintStorageSync(const CallNode* op) final;     // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
index a03e290daa..becc92ee71 100644
--- a/web/apps/browser/rpc_server.html
+++ b/web/apps/browser/rpc_server.html
@@ -130,7 +130,8 @@
     <button onclick="connectRPC()">Connect To Proxy</button>
     <button onclick="clearLog()">Clear Log</button>
     <div id="progress">
-      <label id="progress-tracker-label"></div>
+      <label id="gpu-tracker-label"> </label><br>
+      <label id="progress-tracker-label"> </label> <br>
       <progress id="progress-tracker-progress" max="100" value="100"> 
</progress>
     </div>
     <div id="includeRPCPlugin"></div>
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index 58601230e8..960c0ae18b 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -19,7 +19,7 @@
 
 import { SizeOf, ArgTypeCode } from "./ctypes";
 import { assert, StringToUint8Array, Uint8ArrayToString } from "./support";
-import { detectGPUDevice } from "./webgpu";
+import { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
 import * as compact from "./compact";
 import * as runtime from "./runtime";
 import { Disposable } from "./types";
@@ -272,11 +272,11 @@ export class RPCServer {
       );
 
       try {
-        const gpuDevice: GPUDevice | undefined | null = await 
detectGPUDevice();
-        if (gpuDevice !== undefined && gpuDevice !== null) {
-          const label = gpuDevice.label?.toString() || "WebGPU";
+        const output: GPUDeviceDetectOutput | undefined = await 
detectGPUDevice();
+        if (output !== undefined) {
+          const label = "WebGPU: "+ output.adapterInfo.description;
           this.log("Initialize GPU device: " + label);
-          inst.initWebGPU(gpuDevice);
+          inst.initWebGPU(output.device);
         } else {
           this.log("Cannot find WebGPU device in the env");
         }
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 1f3232c557..c5a9becc7f 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -1373,10 +1373,8 @@ export class Instance implements Disposable {
         text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB 
fetched "
         text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% 
completed, "
         text += timeElapsed + " secs elapsed";
-        if (timeElapsed != 0){
-          text += ", speed=" + (fetchedBytes / timeElapsed / (1024 * 
1024)).toFixed(1) + " MB/sec."
-        }
-        text += " This can take a while during first load.";
+        text += " It can take a while when we first visit this page to 
populate the cache."
+        text += " Later refreshes will become faster.";
         this.fetchProgressCallback[j]({
           fetchedBytes: fetchedBytes,
           totalBytes: totalBytes,
@@ -1391,7 +1389,7 @@ export class Instance implements Disposable {
         fetchedBytes: 0,
         totalBytes: totalBytes,
         timeElapsed: 0,
-        text: "Start to fetch " + ndarrayCacheUrl
+        text: "Start to fetch params",
       });
     }
     const cache = await caches.open("tvmjs");
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 8b5d2ee543..bc466a1543 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -25,18 +25,33 @@ import { Disposable } from "./types";
 /** A pointer to points to the raw address space. */
 export type GPUPointer = number;
 
+export interface GPUDeviceDetectOutput {
+  adapter: GPUAdapter;
+  adapterInfo: GPUAdapterInfo;
+  device: GPUDevice;
+}
+
 /**
  * DetectGPU device in the environment.
  */
-export async function detectGPUDevice(): Promise<GPUDevice | undefined | null> 
{
+export async function detectGPUDevice(): Promise<GPUDeviceDetectOutput | 
undefined> {
   if (typeof navigator !== "undefined" && navigator.gpu !== undefined) {
     const adapter = await navigator.gpu.requestAdapter();
-    return await adapter?.requestDevice({
+    if (adapter == null) {
+      throw Error("Cannot find adapter that matches the request");
+    }
+    const adapterInfo = await adapter.requestAdapterInfo();
+    const device = await adapter.requestDevice({
       requiredLimits: {
         maxStorageBufferBindingSize: 1 << 30,
         maxComputeWorkgroupStorageSize: 32 << 10,
       }
     });
+    return {
+      adapter: adapter,
+      adapterInfo: adapterInfo,
+      device: device
+    };
   } else {
     return undefined;
   }
@@ -403,7 +418,7 @@ export class WebGPUContext {
             }
           }
         }),
-        entryPoint: "main"
+        entryPoint: finfo.name
       }
     });
 

Reply via email to