This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 89d539d746 [Unity] Update specific builtins for LM (#14617)
89d539d746 is described below
commit 89d539d746f8ba0133a72cd9f74626a8b2013ce5
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Apr 14 15:56:16 2023 -0400
[Unity] Update specific builtins for LM (#14617)
This PR updates the specific builtins for LM
and moves them to lm_support.cc
The kv_create now takes initial data and copies it
instead of consuming it.
This will enable us to create kv within a VM more easily.
---
python/tvm/exec/rpc_proxy.py | 2 +-
src/runtime/relax_vm/builtin.cc | 2 +
.../{attention_kv_cache.cc => lm_support.cc} | 88 ++++++++++++++++++++--
src/target/source/codegen_webgpu.cc | 5 ++
tests/python/relax/test_pipeline.py | 19 ++++-
tests/python/relax/test_runtime_builtin.py | 2 +-
web/emcc/wasm_runtime.cc | 2 +-
web/src/runtime.ts | 73 ++++++++++++++----
8 files changed, 167 insertions(+), 26 deletions(-)
diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index eaa406906a..343588a400 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -35,7 +35,7 @@ def find_example_resource():
("/", os.path.join(base_path, "web", "dist", "wasm",
"tvmjs_runtime.wasi.js")),
("/", index_page),
]
- allow_format = ("json", "bin", "js", "wasm", "html", "css")
+ allow_format = ("json", "bin", "js", "wasm", "html", "css", "model")
# recursively apend things in www, up to two levels
resource_bases = [
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 5a7c1d6620..af0963bf41 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -57,6 +57,8 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) {
// TODO(relax-team): visit and consider other possible choices.
if (vm->devices[0].device_type == kDLHexagon) {
host_device_index = 0;
+ } else {
+ ICHECK_EQ(vm->devices[host_device_index].device_type, kDLCPU);
}
auto* alloc = vm->allocators[host_device_index];
return alloc->Empty({size}, DLDataType{kDLInt, 64, 1},
vm->devices[host_device_index]);
diff --git a/src/runtime/relax_vm/attention_kv_cache.cc
b/src/runtime/relax_vm/lm_support.cc
similarity index 66%
rename from src/runtime/relax_vm/attention_kv_cache.cc
rename to src/runtime/relax_vm/lm_support.cc
index f94b233975..9470b86e1a 100644
--- a/src/runtime/relax_vm/attention_kv_cache.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -17,8 +17,10 @@
* under the License.
*/
/*!
- * \file src/runtime/relax_vm/attention_kv_cache.cc
- * \brief A simple implementation of inplace attention kv cache for runtime.
+ * \file src/runtime/relax_vm/lm_support.cc
+ * \brief Runtime to support language model related task
+ *
+ * Including inplace attention kv cache for runtime and simple sampler.
*
* This file provides a simple implementation of inplace attention
* kv cache for relax runtime. The main goal here is to help us enable
@@ -33,7 +35,6 @@
*
* We can evolve this implementation as we build more LM verticals.
*/
-
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
@@ -41,6 +42,8 @@
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/relax_vm/vm.h>
+#include <cmath>
+
namespace tvm {
namespace runtime {
namespace relax_vm {
@@ -79,6 +82,9 @@ class AttentionKVCacheObj : public Object {
return data.CreateView(shape, data->dtype);
}
+ /** Clear the cache */
+ void Clear() { this->fill_count = 0; }
+
/*!
* \brief Append value to the cache.
* \param value The value to be appended.
@@ -124,12 +130,16 @@ class AttentionKVCache : public ObjectRef {
public:
/*!
* \brief Create the attention kv cache.
- * \param init_reserve The initial reserved.
+ * \param init_data The initial reserved.
*/
- static AttentionKVCache Create(NDArray init_data) {
+ static AttentionKVCache Create(NDArray init_data, ShapeTuple reserve_shape,
int init_fill_count) {
auto n = make_object<AttentionKVCacheObj>();
- n->data = std::move(init_data);
+ n->data = NDArray::Empty(reserve_shape, init_data->dtype,
init_data->device);
n->fill_count = 0;
+ n->Append(init_data);
+ if (init_fill_count >= 0) {
+ n->fill_count = init_fill_count;
+ }
return AttentionKVCache(n);
}
@@ -157,6 +167,72 @@ NDArray AttentionKVCacheView(AttentionKVCache cache,
ShapeTuple shape) {
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view").set_body_typed(AttentionKVCacheView);
+void AttentionKVCacheArrayClear(Array<AttentionKVCache> caches) {
+ for (AttentionKVCache cache : caches) {
+ cache->Clear();
+ }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear")
+ .set_body_typed(AttentionKVCacheArrayClear);
+
+// NOTE this is a built-in highly related to LM so we put it here.
+int SampleTopPFromLogits(NDArray logits, double temperature, double top_p,
double uniform_sample) {
+ ICHECK(logits.IsContiguous());
+ ICHECK(logits.DataType() == DataType::Float(32));
+
+ if (logits->device.device_type != kDLCPU) {
+ logits = logits.CopyTo(DLDevice{kDLCPU, 0});
+ }
+
+ ICHECK(logits->device.device_type == kDLCPU);
+
+ for (int i = 0; i < logits->ndim - 1; ++i) {
+ ICHECK_EQ(logits->shape[i], 1) << "The leading dimensions of logits must
be 1";
+ }
+
+ std::vector<std::pair<float, int>> data;
+ data.resize(logits->shape[logits->ndim - 1]);
+ const float* plogits = static_cast<float*>(logits->data);
+ for (size_t i = 0; i < data.size(); ++i) {
+ data[i] = std::make_pair(plogits[i], static_cast<int>(i));
+ }
+ // sort by logits from smallest to largest
+ std::sort(data.begin(), data.end());
+ float max_value = data.back().first;
+ // argmax
+ if (temperature < 1e-6f) {
+ return data.back().second;
+ }
+ // compute expf
+ float sum = 0.0f;
+ for (size_t i = 0; i < data.size(); ++i) {
+ data[i].first = expf(data[i].first - max_value);
+ sum += data[i].first;
+ }
+ // do a cumsum in order of data
+ float cum_sum_prob = 0.0f;
+ float top_p_sum = 0.0f;
+ for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
+ float prob = rit->first / sum;
+ if (cum_sum_prob < top_p) {
+ top_p_sum += prob;
+ }
+ cum_sum_prob += prob;
+ rit->first = cum_sum_prob;
+ }
+ // pick a number based on random in (0, 1)
+ for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
+ if (uniform_sample < rit->first / top_p_sum) {
+ return rit->second;
+ }
+ }
+ ICHECK_LE(uniform_sample, data[0].first);
+ return data[0].second;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits);
+
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index d56a5d547f..95c46b8894 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -117,6 +117,11 @@ CodeGenWebGPU::CodeGenWebGPU(Target target) :
target_(target) {}
runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool
skip_readonly_decl) {
// clear previous generated state.
this->InitFuncState(f);
+ // reserve keywords
+ name_supply_->ReserveName("var");
+ name_supply_->ReserveName("let");
+ name_supply_->ReserveName("const");
+
// skip the first underscore, so SSA variable starts from
name_supply_->FreshName("v_");
// Setup the thread group info.
diff --git a/tests/python/relax/test_pipeline.py
b/tests/python/relax/test_pipeline.py
index 3c97d6b701..2dac42b334 100644
--- a/tests/python/relax/test_pipeline.py
+++ b/tests/python/relax/test_pipeline.py
@@ -52,6 +52,20 @@ def test_pipeline_with_kv_cache():
@tvm.script.ir_module
class Mod:
+ @R.function
+ def create_kv_cache(reserve_slots: R.Shape(["m"])):
+ # just allocate minimum slot since it is only used to signal dtype
+ m = T.int64()
+ init_data = R.ones((1, 4), "float32")
+ kv_cache = R.call_packed(
+ "vm.builtin.attention_kv_cache_create",
+ init_data,
+ R.shape([m, 4]),
+ 0,
+ sinfo_args=[R.Object],
+ )
+ return kv_cache
+
@R.function
def main(
x: R.Tensor((1, 4), "float32"),
@@ -84,11 +98,10 @@ def test_pipeline_with_kv_cache():
num_steps = 8
cache_np = np.empty((num_steps, 4), dtype="float32")
+ vm = relax.VirtualMachine(ex, tvm.cpu())
- fcreate_cache = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
- kv_cache = fcreate_cache(tvm.nd.empty((2, 4), device=tvm.cpu(),
dtype="float32"))
+ kv_cache = vm["create_kv_cache"](tvm.runtime.ShapeTuple([1]))
- vm = relax.VirtualMachine(ex, tvm.cpu())
for i in range(num_steps):
x_np = np.random.rand(1, 4).astype(np.float32)
y_np = np.random.rand(1, 4).astype(np.float32)
diff --git a/tests/python/relax/test_runtime_builtin.py
b/tests/python/relax/test_runtime_builtin.py
index bb513eb357..6ba06d0693 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -155,7 +155,7 @@ def test_attention_kv_cache():
fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append")
fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
- cache = fcreate(tvm.nd.empty((2, 2), dtype="int32"))
+ cache = fcreate(tvm.nd.empty((1, 2), dtype="int32"),
tvm.runtime.ShapeTuple([2, 2]), 0)
num_steps = 0
for i in range(num_steps):
cache = fappend(cache, tvm.nd.array(i * np.ones((1,
2).astype("int32"))))
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 9b2b6c180c..2412fb8d2d 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -53,10 +53,10 @@
#include "src/runtime/system_library.cc"
#include "src/runtime/workspace_pool.cc"
// relax setup
-#include "src/runtime/relax_vm/attention_kv_cache.cc"
#include "src/runtime/relax_vm/builtin.cc"
#include "src/runtime/relax_vm/bytecode.cc"
#include "src/runtime/relax_vm/executable.cc"
+#include "src/runtime/relax_vm/lm_support.cc"
#include "src/runtime/relax_vm/memory_manager.cc"
#include "src/runtime/relax_vm/vm.cc"
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 427bb2d8f8..26ba1ecd67 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -143,13 +143,16 @@ class RuntimeContext implements Disposable {
arrayGetItem : PackedFunc;
arrayGetSize : PackedFunc;
arrayMake : PackedFunc;
- getSysLib: PackedFunc;
- arrayCacheGet: PackedFunc;
- arrayCacheUpdate: PackedFunc;
- arrayCacheRemove: PackedFunc;
- arrayCacheClear: PackedFunc;
- arrayDecodeStorage: PackedFunc;
- paramModuleFromCache: PackedFunc;
+ getSysLib : PackedFunc;
+ arrayCacheGet : PackedFunc;
+ arrayCacheUpdate : PackedFunc;
+ arrayCacheRemove : PackedFunc;
+ arrayCacheClear : PackedFunc;
+ arrayDecodeStorage : PackedFunc;
+ paramModuleFromCache : PackedFunc;
+ makeShapeTuple : PackedFunc;
+ ndarrayCreateView : PackedFunc;
+ sampleTopPFromLogits : PackedFunc;
private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
@@ -164,12 +167,14 @@ class RuntimeContext implements Disposable {
this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear");
this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
this.paramModuleFromCache = getGlobalFunc("tvmjs.param_module_from_cache");
-
+ this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
+ this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
+ this.sampleTopPFromLogits =
getGlobalFunc("vm.builtin.sample_top_p_from_logits");
}
dispose(): void {
// call array cache clear to clear all cached items
- this.arrayCacheClear();
+ this.arrayCacheClear.dispose();
this.arrayGetItem.dispose();
this.arrayGetSize.dispose();
this.arrayMake.dispose();
@@ -179,6 +184,9 @@ class RuntimeContext implements Disposable {
this.arrayCacheClear.dispose();
this.arrayDecodeStorage.dispose();
this.paramModuleFromCache.dispose();
+ this.makeShapeTuple.dispose();
+ this.ndarrayCreateView.dispose();
+ this.sampleTopPFromLogits.dispose();
}
beginScope() : void {
@@ -419,12 +427,14 @@ export class NDArray implements Disposable {
private dltensor: Pointer;
private dataPtr: Pointer;
private lib: FFILibrary;
+ private ctx: RuntimeContext;
private dlDataType: DLDataType;
- constructor(handle: Pointer, isView: boolean, lib: FFILibrary) {
+ constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx:
RuntimeContext) {
this.handle = handle;
this.isView = isView;
this.lib = lib;
+ this.ctx = ctx;
if (this.isView) {
this.dltensor = handle;
@@ -470,6 +480,16 @@ export class NDArray implements Disposable {
this.byteOffset = lib.memory.loadI64(this.dltensor +
arrayOffsetByteOffset);
}
+ /**
+ * Create a view of the array.
+ * @param shape The shape of the view.
+ * @returns The new sliced ndarray.
+ */
+ view(shape: Array<number>) : NDArray {
+ const shapeArray = shape.map((value) => new Scalar(value, "int"));
+ return this.ctx.ndarrayCreateView(this,
this.ctx.makeShapeTuple(...shapeArray));
+ }
+
/**
* Get handle of ndarray, check it is not null.
*
@@ -870,7 +890,11 @@ export class VirtualMachine implements Disposable {
this.mod.getFunction("vm_initialization")(
new Scalar(device.deviceType, "int"),
new Scalar(device.deviceId, "int"),
- new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int")
+ new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"),
+ // explicitly specify host device type
+ new Scalar(DeviceStrToEnum.cpu, "int"),
+ new Scalar(0, "int"),
+ new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"),
);
}
@@ -1581,7 +1605,7 @@ export class Instance implements Disposable {
)
);
const ret = this.ctx.attachToCurrentScope(
- new NDArray(this.memory.loadPointer(outPtr), false, this.lib)
+ new NDArray(this.memory.loadPointer(outPtr), false, this.lib, this.ctx)
);
this.lib.recycleCallStack(stack);
return ret;
@@ -1614,6 +1638,18 @@ export class Instance implements Disposable {
return ret.copyFrom(input);
}
+ /**
+ * Sample index via top-p sampling.
+ *
+ * @param logits The input logits before normalization.
+ * @param temperature The temperature factor, will take argmax if
temperature = 0.0
+ * @param top_p The top_p
+ * @returns The sampled index.
+ */
+ sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number):
number {
+ return this.ctx.sampleTopPFromLogits(logits, temperature, top_p,
Math.random());
+ }
+
/**
* Bind canvas to the current WebGPU context
* @param canvas The canvas.
@@ -1668,6 +1704,15 @@ export class Instance implements Disposable {
return this.ctx.arrayMake(...inputs) as TVMArray;
}
+ /**
+ * Create a shape tuple to pass to runtime.
+ * @param shape The shape .
+ * @returns The created shape tuple.
+ */
+ makeShapeTuple(shape: Array<number>) : TVMObject {
+ const shapeArray = shape.map((value) => new Scalar(value, "int"));
+ return this.ctx.makeShapeTuple(...shapeArray);
+ }
/**
* Get type index from type key.
* @param typeKey The type key.
@@ -2131,13 +2176,13 @@ export class Instance implements Disposable {
}
case ArgTypeCode.TVMNDArrayHandle: {
return this.ctx.attachToCurrentScope(
- new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib)
+ new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib,
this.ctx)
);
}
case ArgTypeCode.TVMDLTensorHandle: {
assert(callbackArg);
// no need to attach as we are only looking at view
- return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib);
+ return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib,
this.ctx);
}
case ArgTypeCode.TVMPackedFuncHandle: {
return this.ctx.attachToCurrentScope(