This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cf1b5da9a6 [Unity][WEB] Simplify WebGPU Codegen per spec (#14225)
cf1b5da9a6 is described below
commit cf1b5da9a63257b174fbab580e3b634fa2d71304
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Mar 7 13:00:00 2023 -0500
[Unity][WEB] Simplify WebGPU Codegen per spec (#14225)
This PR simplifies the WebGPU codegen according to the spec, and also adds utils
to allow tvmjs to export the runtime from a different TVM_HOME.
---
python/tvm/_ffi/libinfo.py | 12 +++++++++---
python/tvm/contrib/tvmjs.py | 30 ++++++++++++++++++++++++++++++
src/target/source/codegen_webgpu.cc | 37 +++----------------------------------
3 files changed, 42 insertions(+), 37 deletions(-)
diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py
index b27d4247fd..acaba17361 100644
--- a/python/tvm/_ffi/libinfo.py
+++ b/python/tvm/_ffi/libinfo.py
@@ -74,9 +74,15 @@ def get_dll_directories():
dll_path.append(install_lib_dir)
- if os.path.isdir(source_dir):
- dll_path.append(os.path.join(source_dir, "web", "dist", "wasm"))
- dll_path.append(os.path.join(source_dir, "web", "dist"))
+ # use extra TVM_HOME environment for finding libraries.
+ if os.environ.get("TVM_HOME", None):
+ tvm_source_home_dir = os.environ["TVM_HOME"]
+ else:
+ tvm_source_home_dir = source_dir
+
+ if os.path.isdir(tvm_source_home_dir):
+ dll_path.append(os.path.join(tvm_source_home_dir, "web", "dist",
"wasm"))
+ dll_path.append(os.path.join(tvm_source_home_dir, "web", "dist"))
dll_path = [os.path.realpath(x) for x in dll_path]
return [x for x in dll_path if os.path.isdir(x)]
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 49626e725d..f47cdc4dcb 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -19,11 +19,13 @@
import sys
import os
import json
+import shutil
from typing import Mapping, Union
import numpy as np
import tvm
+from tvm._ffi.libinfo import find_lib_path
from .emcc import create_tvmjs_wasm
@@ -170,3 +172,31 @@ def load_ndarray_cache(cachepath: str, device:
tvm.runtime.Device):
arr.copyfrom(data)
result_dict[name] = arr
return result_dict, json_info["metadata"]
+
+
+def export_runtime(runtime_dir):
+ """Export TVMJS runtime to the runtime_dir
+
+ Parameters
+ ----------
+ runtime_dir: str
+ The runtime directory
+ """
+ web_hint = (
+ "make sure you setup tvm web runtime correctly."
+ + " obtain a copy of TVM source code, set TVM_HOME env variable:\n"
+ + " cd /path/to/tvm/web; make; npm run bundle"
+ )
+
+ jsbundle = find_lib_path("tvmjs.bundle.js", optional=True)
+ if not jsbundle:
+ raise RuntimeError("Cannot find tvmjs.bundle.js, " + web_hint)
+
+ wasi = find_lib_path("tvmjs_runtime.wasi.js", optional=True)
+ if not wasi:
+ raise RuntimeError("Cannot find tvmjs_runtime.wasi.js, " + web_hint)
+
+ print(f"Copy {jsbundle[0]} to {runtime_dir}")
+ shutil.copy(jsbundle[0], runtime_dir)
+ print(f"Copy {wasi[0]} to {runtime_dir}")
+ shutil.copy(wasi[0], runtime_dir)
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index a4c2ba0b62..8ba2b4a65e 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -335,32 +335,8 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLIN
this->PrintExpr(EnforceU32(op->args[1]), os);
os << ')';
} else if (op->op.same_as(builtin::if_then_else())) {
+ // WebGPU will insert clamping in buffer access so no need to check OOB.
this->PrintExpr(Select(op->args[0], op->args[1], op->args[2]), os);
- return;
- // conditional that skips eval if cond evals to false
- std::string result = name_supply_->FreshName("condval");
- std::string cond = PrintExpr(op->args[0]);
- this->PrintIndent();
- this->stream << "var " << result << " : ";
- PrintType(op->dtype, this->stream);
- this->stream << ";\n";
- this->PrintIndent();
- this->stream << "if (" << cond << ") {\n";
- {
- int then_scope = this->BeginScope();
- std::string true_val = PrintExpr(op->args[1]);
- this->PrintIndent();
- this->stream << result << " = " << true_val << ";\n} else {\n";
- this->EndScope(then_scope);
- }
- {
- int else_scope = this->BeginScope();
- std::string false_val = PrintExpr(op->args[2]);
- this->PrintIndent();
- this->stream << result << " = " << false_val << ";\n}\n";
- this->EndScope(else_scope);
- }
- os << result;
} else {
CodeGenC::VisitExpr_(op, os);
}
@@ -570,19 +546,12 @@ void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
-
PrintIndent();
- stream << "var " << vid << " : ";
+ stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
- stream << " = 0;\n";
- PrintIndent();
- stream << "loop {\n";
+ stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
int for_scope = BeginScope();
- PrintIndent();
- stream << "if " << vid << " >= " << extent << " { break; }\n";
PrintStmt(op->body);
- PrintIndent();
- stream << vid << "++;\n";
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";