This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 89d539d746 [Unity] Update specific builtins for LM (#14617)
89d539d746 is described below

commit 89d539d746f8ba0133a72cd9f74626a8b2013ce5
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Apr 14 15:56:16 2023 -0400

    [Unity] Update specific builtins for LM (#14617)
    
    This PR updates the specific builtins for LM
    and moves them to lm_support.cc.
    
    kv_create now takes initial data and copies it instead of consuming it.
    This makes it easier to create a kv cache from within a VM.
---
 python/tvm/exec/rpc_proxy.py                       |  2 +-
 src/runtime/relax_vm/builtin.cc                    |  2 +
 .../{attention_kv_cache.cc => lm_support.cc}       | 88 ++++++++++++++++++++--
 src/target/source/codegen_webgpu.cc                |  5 ++
 tests/python/relax/test_pipeline.py                | 19 ++++-
 tests/python/relax/test_runtime_builtin.py         |  2 +-
 web/emcc/wasm_runtime.cc                           |  2 +-
 web/src/runtime.ts                                 | 73 ++++++++++++++----
 8 files changed, 167 insertions(+), 26 deletions(-)

diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index eaa406906a..343588a400 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -35,7 +35,7 @@ def find_example_resource():
         ("/", os.path.join(base_path, "web", "dist", "wasm", 
"tvmjs_runtime.wasi.js")),
         ("/", index_page),
     ]
-    allow_format = ("json", "bin", "js", "wasm", "html", "css")
+    allow_format = ("json", "bin", "js", "wasm", "html", "css", "model")
 
     # recursively apend things in www, up to two levels
     resource_bases = [
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 5a7c1d6620..af0963bf41 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -57,6 +57,8 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) {
   // TODO(relax-team): visit and consider other possible choices.
   if (vm->devices[0].device_type == kDLHexagon) {
     host_device_index = 0;
+  } else {
+    ICHECK_EQ(vm->devices[host_device_index].device_type, kDLCPU);
   }
   auto* alloc = vm->allocators[host_device_index];
   return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, 
vm->devices[host_device_index]);
diff --git a/src/runtime/relax_vm/attention_kv_cache.cc 
b/src/runtime/relax_vm/lm_support.cc
similarity index 66%
rename from src/runtime/relax_vm/attention_kv_cache.cc
rename to src/runtime/relax_vm/lm_support.cc
index f94b233975..9470b86e1a 100644
--- a/src/runtime/relax_vm/attention_kv_cache.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -17,8 +17,10 @@
  * under the License.
  */
 /*!
- * \file src/runtime/relax_vm/attention_kv_cache.cc
- * \brief A simple implementation of inplace attention kv cache for runtime.
+ * \file src/runtime/relax_vm/lm_support.cc
+ * \brief Runtime to support language model related task
+ *
+ * Including inplace attention kv cache for runtime and simple sampler.
  *
  * This file provides a simple implementation of inplace attention
  * kv cache for relax runtime. The main goal here is to help us enable
@@ -33,7 +35,6 @@
  *
  * We can evolve this implementation as we build more LM verticals.
  */
-
 #include <tvm/runtime/container/shape_tuple.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/logging.h>
@@ -41,6 +42,8 @@
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/relax_vm/vm.h>
 
+#include <cmath>
+
 namespace tvm {
 namespace runtime {
 namespace relax_vm {
@@ -79,6 +82,9 @@ class AttentionKVCacheObj : public Object {
     return data.CreateView(shape, data->dtype);
   }
 
+  /** Clear the cache */
+  void Clear() { this->fill_count = 0; }
+
   /*!
    * \brief Append value to the cache.
    * \param value The value to be appended.
@@ -124,12 +130,20 @@ class AttentionKVCache : public ObjectRef {
  public:
   /*!
    * \brief Create the attention kv cache.
-   * \param init_reserve The initial reserved.
+   * \param init_data Initial content; it is appended (copied) into the cache
+   *        and also determines the dtype and device of the cache storage.
+   * \param reserve_shape Shape of the storage allocated up front for the cache.
+   * \param init_fill_count If non-negative, overrides the fill count after
+   *        init_data is appended (e.g. 0 uses init_data only for dtype/device).
     */
-  static AttentionKVCache Create(NDArray init_data) {
+  static AttentionKVCache Create(NDArray init_data, ShapeTuple reserve_shape, int init_fill_count) {
     auto n = make_object<AttentionKVCacheObj>();
-    n->data = std::move(init_data);
+    n->data = NDArray::Empty(reserve_shape, init_data->dtype, init_data->device);
     n->fill_count = 0;
+    n->Append(init_data);
+    if (init_fill_count >= 0) {
+      n->fill_count = init_fill_count;
+    }
     return AttentionKVCache(n);
   }
 
@@ -157,6 +167,72 @@ NDArray AttentionKVCacheView(AttentionKVCache cache, 
ShapeTuple shape) {
 
 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view").set_body_typed(AttentionKVCacheView);
 
+void AttentionKVCacheArrayClear(Array<AttentionKVCache> caches) {
+  for (AttentionKVCache cache : caches) {
+    cache->Clear();
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear")
+    .set_body_typed(AttentionKVCacheArrayClear);
+
+// NOTE this is a built-in highly related to LM so we put it here.
+int SampleTopPFromLogits(NDArray logits, double temperature, double top_p, double uniform_sample) {
+  ICHECK(logits.IsContiguous());
+  ICHECK(logits.DataType() == DataType::Float(32));
+
+  if (logits->device.device_type != kDLCPU) {
+    logits = logits.CopyTo(DLDevice{kDLCPU, 0});
+  }
+
+  ICHECK_GE(logits->ndim, 1) << "logits must have at least one dimension";
+  ICHECK_GT(logits->shape[logits->ndim - 1], 0) << "logits must not be empty";
+  for (int i = 0; i < logits->ndim - 1; ++i) {
+    ICHECK_EQ(logits->shape[i], 1) << "The leading dimensions of logits must be 1";
+  }
+
+  std::vector<std::pair<float, int>> data;
+  data.resize(logits->shape[logits->ndim - 1]);
+  const float* plogits = static_cast<float*>(logits->data);
+  for (size_t i = 0; i < data.size(); ++i) {
+    data[i] = std::make_pair(plogits[i], static_cast<int>(i));
+  }
+  // sort by logits from smallest to largest
+  std::sort(data.begin(), data.end());
+  float max_value = data.back().first;
+  // argmax when temperature is (near) zero
+  if (temperature < 1e-6f) {
+    return data.back().second;
+  }
+  // compute expf of the temperature-scaled, max-shifted logits (softmax numerator)
+  float sum = 0.0f;
+  for (size_t i = 0; i < data.size(); ++i) {
+    data[i].first = expf(static_cast<float>((data[i].first - max_value) / temperature));
+    sum += data[i].first;
+  }
+  // do a cumsum in order of data
+  float cum_sum_prob = 0.0f;
+  float top_p_sum = 0.0f;
+  for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
+    float prob = rit->first / sum;
+    if (cum_sum_prob < top_p) {
+      top_p_sum += prob;
+    }
+    cum_sum_prob += prob;
+    rit->first = cum_sum_prob;
+  }
+  // pick a number based on random in (0, 1)
+  for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
+    if (uniform_sample < rit->first / top_p_sum) {
+      return rit->second;
+    }
+  }
+  ICHECK_LE(uniform_sample, data[0].first);
+  return data[0].second;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits);
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index d56a5d547f..95c46b8894 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -117,6 +117,11 @@ CodeGenWebGPU::CodeGenWebGPU(Target target) : 
target_(target) {}
 runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool 
skip_readonly_decl) {
   // clear previous generated state.
   this->InitFuncState(f);
+  // reserve keywords
+  name_supply_->ReserveName("var");
+  name_supply_->ReserveName("let");
+  name_supply_->ReserveName("const");
+
   // skip the first underscore, so SSA variable starts from
   name_supply_->FreshName("v_");
   // Setup the thread group info.
diff --git a/tests/python/relax/test_pipeline.py 
b/tests/python/relax/test_pipeline.py
index 3c97d6b701..2dac42b334 100644
--- a/tests/python/relax/test_pipeline.py
+++ b/tests/python/relax/test_pipeline.py
@@ -52,6 +52,20 @@ def test_pipeline_with_kv_cache():
 
     @tvm.script.ir_module
     class Mod:
+        @R.function
+        def create_kv_cache(reserve_slots: R.Shape(["m"])):
+            # just allocate minimum slot since it is only used to signal dtype
+            m = T.int64()
+            init_data = R.ones((1, 4), "float32")
+            kv_cache = R.call_packed(
+                "vm.builtin.attention_kv_cache_create",
+                init_data,
+                R.shape([m, 4]),
+                0,
+                sinfo_args=[R.Object],
+            )
+            return kv_cache
+
         @R.function
         def main(
             x: R.Tensor((1, 4), "float32"),
@@ -84,11 +98,10 @@ def test_pipeline_with_kv_cache():
 
     num_steps = 8
     cache_np = np.empty((num_steps, 4), dtype="float32")
+    vm = relax.VirtualMachine(ex, tvm.cpu())
 
-    fcreate_cache = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
-    kv_cache = fcreate_cache(tvm.nd.empty((2, 4), device=tvm.cpu(), 
dtype="float32"))
+    kv_cache = vm["create_kv_cache"](tvm.runtime.ShapeTuple([1]))
 
-    vm = relax.VirtualMachine(ex, tvm.cpu())
     for i in range(num_steps):
         x_np = np.random.rand(1, 4).astype(np.float32)
         y_np = np.random.rand(1, 4).astype(np.float32)
diff --git a/tests/python/relax/test_runtime_builtin.py 
b/tests/python/relax/test_runtime_builtin.py
index bb513eb357..6ba06d0693 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -155,7 +155,7 @@ def test_attention_kv_cache():
     fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append")
     fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
 
-    cache = fcreate(tvm.nd.empty((2, 2), dtype="int32"))
+    cache = fcreate(tvm.nd.empty((1, 2), dtype="int32"), 
tvm.runtime.ShapeTuple([2, 2]), 0)
     num_steps = 0
     for i in range(num_steps):
         cache = fappend(cache, tvm.nd.array(i * np.ones((1, 
2).astype("int32"))))
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 9b2b6c180c..2412fb8d2d 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -53,10 +53,10 @@
 #include "src/runtime/system_library.cc"
 #include "src/runtime/workspace_pool.cc"
 // relax setup
-#include "src/runtime/relax_vm/attention_kv_cache.cc"
 #include "src/runtime/relax_vm/builtin.cc"
 #include "src/runtime/relax_vm/bytecode.cc"
 #include "src/runtime/relax_vm/executable.cc"
+#include "src/runtime/relax_vm/lm_support.cc"
 #include "src/runtime/relax_vm/memory_manager.cc"
 #include "src/runtime/relax_vm/vm.cc"
 
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 427bb2d8f8..26ba1ecd67 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -143,13 +143,16 @@ class RuntimeContext implements Disposable {
   arrayGetItem : PackedFunc;
   arrayGetSize : PackedFunc;
   arrayMake : PackedFunc;
-  getSysLib: PackedFunc;
-  arrayCacheGet: PackedFunc;
-  arrayCacheUpdate: PackedFunc;
-  arrayCacheRemove: PackedFunc;
-  arrayCacheClear: PackedFunc;
-  arrayDecodeStorage: PackedFunc;
-  paramModuleFromCache: PackedFunc;
+  getSysLib : PackedFunc;
+  arrayCacheGet : PackedFunc;
+  arrayCacheUpdate : PackedFunc;
+  arrayCacheRemove : PackedFunc;
+  arrayCacheClear : PackedFunc;
+  arrayDecodeStorage : PackedFunc;
+  paramModuleFromCache : PackedFunc;
+  makeShapeTuple : PackedFunc;
+  ndarrayCreateView : PackedFunc;
+  sampleTopPFromLogits : PackedFunc;
 
   private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
 
@@ -164,12 +167,14 @@ class RuntimeContext implements Disposable {
     this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear");
     this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
     this.paramModuleFromCache = getGlobalFunc("tvmjs.param_module_from_cache");
-
+    this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
+    this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
+    this.sampleTopPFromLogits = 
getGlobalFunc("vm.builtin.sample_top_p_from_logits");
   }
 
   dispose(): void {
     // call array cache clear to clear all cached items
-    this.arrayCacheClear();
+    this.arrayCacheClear.dispose();
     this.arrayGetItem.dispose();
     this.arrayGetSize.dispose();
     this.arrayMake.dispose();
@@ -179,6 +184,9 @@ class RuntimeContext implements Disposable {
     this.arrayCacheClear.dispose();
     this.arrayDecodeStorage.dispose();
     this.paramModuleFromCache.dispose();
+    this.makeShapeTuple.dispose();
+    this.ndarrayCreateView.dispose();
+    this.sampleTopPFromLogits.dispose();
   }
 
   beginScope() : void {
@@ -419,12 +427,14 @@ export class NDArray implements Disposable {
   private dltensor: Pointer;
   private dataPtr: Pointer;
   private lib: FFILibrary;
+  private ctx: RuntimeContext;
   private dlDataType: DLDataType;
 
-  constructor(handle: Pointer, isView: boolean, lib: FFILibrary) {
+  constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx: 
RuntimeContext) {
     this.handle = handle;
     this.isView = isView;
     this.lib = lib;
+    this.ctx = ctx;
 
     if (this.isView) {
       this.dltensor = handle;
@@ -470,6 +480,16 @@ export class NDArray implements Disposable {
     this.byteOffset = lib.memory.loadI64(this.dltensor + 
arrayOffsetByteOffset);
   }
 
+  /**
+   * Create a view of the array.
+   * @param shape The shape of the view.
+   * @returns The new sliced ndarray.
+   */
+  view(shape: Array<number>) : NDArray {
+    const shapeArray = shape.map((value) => new Scalar(value, "int"));
+    return this.ctx.ndarrayCreateView(this, 
this.ctx.makeShapeTuple(...shapeArray));
+  }
+
  /**
   * Get handle of ndarray, check it is not null.
   *
@@ -870,7 +890,11 @@ export class VirtualMachine implements Disposable {
     this.mod.getFunction("vm_initialization")(
       new Scalar(device.deviceType, "int"),
       new Scalar(device.deviceId, "int"),
-      new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int")
+      new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"),
+      // explicitly specify host device type
+      new Scalar(DeviceStrToEnum.cpu, "int"),
+      new Scalar(0, "int"),
+      new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"),
     );
   }
 
@@ -1581,7 +1605,7 @@ export class Instance implements Disposable {
       )
     );
     const ret = this.ctx.attachToCurrentScope(
-      new NDArray(this.memory.loadPointer(outPtr), false, this.lib)
+      new NDArray(this.memory.loadPointer(outPtr), false, this.lib, this.ctx)
     );
     this.lib.recycleCallStack(stack);
     return ret;
@@ -1614,6 +1638,18 @@ export class Instance implements Disposable {
     return ret.copyFrom(input);
   }
 
+  /**
+   * Sample index via top-p sampling.
+   *
+   * @param logits The input logits before normalization.
+   * @param temperature The temperature factor; argmax is taken if temperature = 0.0.
+   * @param top_p Nucleus threshold: sample among the most likely tokens whose cumulative probability reaches top_p.
+   * @returns The sampled index.
+   */
+  sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number {
+    return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
+  }
+
   /**
    * Bind canvas to the current WebGPU context
    * @param canvas The canvas.
@@ -1668,6 +1704,15 @@ export class Instance implements Disposable {
     return this.ctx.arrayMake(...inputs) as TVMArray;
   }
 
+  /**
+   * Create a shape tuple to pass to runtime.
+   * @param shape The shape dimensions.
+   * @returns The created shape tuple.
+   */
+  makeShapeTuple(shape: Array<number>) : TVMObject {
+    const shapeArray = shape.map((value) => new Scalar(value, "int"));
+    return this.ctx.makeShapeTuple(...shapeArray);
+  }
   /**
    * Get type index from type key.
    * @param typeKey The type key.
@@ -2131,13 +2176,13 @@ export class Instance implements Disposable {
         }
       case ArgTypeCode.TVMNDArrayHandle: {
         return this.ctx.attachToCurrentScope(
-          new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib)
+          new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, 
this.ctx)
         );
       }
       case ArgTypeCode.TVMDLTensorHandle: {
         assert(callbackArg);
         // no need to attach as we are only looking at view
-        return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib);
+        return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib, 
this.ctx);
       }
       case ArgTypeCode.TVMPackedFuncHandle: {
         return this.ctx.attachToCurrentScope(

Reply via email to