This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cf1b5da9a6 [Unity][WEB] Simplify WebGPU Codegen per spec (#14225)
cf1b5da9a6 is described below

commit cf1b5da9a63257b174fbab580e3b634fa2d71304
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Mar 7 13:00:00 2023 -0500

    [Unity][WEB] Simplify WebGPU Codegen per spec (#14225)
    
    This PR simplifies the WebGPU codegen according to the spec, and also adds
    utilities that allow tvmjs to export the runtime from a different TVM_HOME.
---
 python/tvm/_ffi/libinfo.py          | 12 +++++++++---
 python/tvm/contrib/tvmjs.py         | 30 ++++++++++++++++++++++++++++++
 src/target/source/codegen_webgpu.cc | 37 +++----------------------------------
 3 files changed, 42 insertions(+), 37 deletions(-)

diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py
index b27d4247fd..acaba17361 100644
--- a/python/tvm/_ffi/libinfo.py
+++ b/python/tvm/_ffi/libinfo.py
@@ -74,9 +74,15 @@ def get_dll_directories():
 
     dll_path.append(install_lib_dir)
 
-    if os.path.isdir(source_dir):
-        dll_path.append(os.path.join(source_dir, "web", "dist", "wasm"))
-        dll_path.append(os.path.join(source_dir, "web", "dist"))
+    # use extra TVM_HOME environment for finding libraries.
+    if os.environ.get("TVM_HOME", None):
+        tvm_source_home_dir = os.environ["TVM_HOME"]
+    else:
+        tvm_source_home_dir = source_dir
+
+    if os.path.isdir(tvm_source_home_dir):
+        dll_path.append(os.path.join(tvm_source_home_dir, "web", "dist", "wasm"))
+        dll_path.append(os.path.join(tvm_source_home_dir, "web", "dist"))
 
     dll_path = [os.path.realpath(x) for x in dll_path]
     return [x for x in dll_path if os.path.isdir(x)]
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 49626e725d..f47cdc4dcb 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -19,11 +19,13 @@
 import sys
 import os
 import json
+import shutil
 from typing import Mapping, Union
 
 import numpy as np
 
 import tvm
+from tvm._ffi.libinfo import find_lib_path
 from .emcc import create_tvmjs_wasm
 
 
@@ -170,3 +172,31 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
             arr.copyfrom(data)
         result_dict[name] = arr
     return result_dict, json_info["metadata"]
+
+
+def export_runtime(runtime_dir):
+    """Export TVMJS runtime to the runtime_dir
+
+    Parameters
+    ----------
+    runtime_dir: str
+        The runtime directory
+    """
+    web_hint = (
+        "make sure you setup tvm web runtime correctly."
+        + " obtain a copy of TVM source code, set TVM_HOME env variable:\n"
+        + " cd /path/to/tvm/web; make; npm run bundle"
+    )
+
+    jsbundle = find_lib_path("tvmjs.bundle.js", optional=True)
+    if not jsbundle:
+        raise RuntimeError("Cannot find tvmjs.bundle.js, " + web_hint)
+
+    wasi = find_lib_path("tvmjs_runtime.wasi.js", optional=True)
+    if not wasi:
+        raise RuntimeError("Cannot find tvmjs_runtime.wasi.js, " + web_hint)
+
+    print(f"Copy {jsbundle[0]} to {runtime_dir}")
+    shutil.copy(jsbundle[0], runtime_dir)
+    print(f"Copy {wasi[0]} to {runtime_dir}")
+    shutil.copy(wasi[0], runtime_dir)
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index a4c2ba0b62..8ba2b4a65e 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -335,32 +335,8 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLIN
     this->PrintExpr(EnforceU32(op->args[1]), os);
     os << ')';
   } else if (op->op.same_as(builtin::if_then_else())) {
+    // WebGPU will insert clamping in buffer access so no need to check OOB.
     this->PrintExpr(Select(op->args[0], op->args[1], op->args[2]), os);
-    return;
-    // conditional that skips eval if cond evals to false
-    std::string result = name_supply_->FreshName("condval");
-    std::string cond = PrintExpr(op->args[0]);
-    this->PrintIndent();
-    this->stream << "var " << result << " : ";
-    PrintType(op->dtype, this->stream);
-    this->stream << ";\n";
-    this->PrintIndent();
-    this->stream << "if (" << cond << ") {\n";
-    {
-      int then_scope = this->BeginScope();
-      std::string true_val = PrintExpr(op->args[1]);
-      this->PrintIndent();
-      this->stream << result << " = " << true_val << ";\n} else {\n";
-      this->EndScope(then_scope);
-    }
-    {
-      int else_scope = this->BeginScope();
-      std::string false_val = PrintExpr(op->args[2]);
-      this->PrintIndent();
-      this->stream << result << " = " << false_val << ";\n}\n";
-      this->EndScope(else_scope);
-    }
-    os << result;
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
@@ -570,19 +546,12 @@ void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
   std::string extent = PrintExpr(op->extent);
   std::string vid = AllocVarID(op->loop_var.get());
   ICHECK(is_zero(op->min));
-
   PrintIndent();
-  stream << "var " << vid << " : ";
+  stream << "for (var " << vid << " : ";
   PrintType(op->loop_var.dtype(), stream);
-  stream << " = 0;\n";
-  PrintIndent();
-  stream << "loop {\n";
+  stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
   int for_scope = BeginScope();
-  PrintIndent();
-  stream << "if " << vid << " >= " << extent << " { break; }\n";
   PrintStmt(op->body);
-  PrintIndent();
-  stream << vid << "++;\n";
   this->EndScope(for_scope);
   PrintIndent();
   stream << "}\n";

Reply via email to