This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3ef478b486 [Relax][Runtime] RNNState for Space State Models (#16568)
3ef478b486 is described below
commit 3ef478b4864c110c33fde937b9f6b8e604e957d3
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Feb 21 21:32:23 2024 +0800
[Relax][Runtime] RNNState for Space State Models (#16568)
* [Relax][Runtime] RNNState for Space State Models
This commit adds the RNNState class to the Relax VM, similar to the
PagedKVCache, for state space models like RWKV and Mamba.
* refactor
---
src/runtime/relax_vm/kv_state.cc | 80 ++++
src/runtime/relax_vm/{kv_cache.h => kv_state.h} | 118 +++--
src/runtime/relax_vm/lm_support.cc | 11 +-
src/runtime/relax_vm/paged_kv_cache.cc | 41 +-
src/runtime/relax_vm/rnn_state.cc | 487 +++++++++++++++++++++
.../python/relax/test_runtime_builtin_rnn_state.py | 262 +++++++++++
6 files changed, 947 insertions(+), 52 deletions(-)
diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
new file mode 100644
index 0000000000..7c86e96ec6
--- /dev/null
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "kv_state.h"
+
+#include <utility>
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+// Register Object Type
+TVM_REGISTER_OBJECT_TYPE(KVStateObj);
+TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
+TVM_REGISTER_OBJECT_TYPE(RNNStateObj);
+
+// KV State base methods
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method<KVState>(&KVStateObj::Clear);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence")
+ .set_body_method<KVState>(&KVStateObj::AddSequence);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence")
+ .set_body_method<KVState>(&KVStateObj::RemoveSequence);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
+ .set_body_method<KVState>(&KVStateObj::ForkSequence);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
+ .set_body_method<KVState>(&KVStateObj::BeginForward);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
+ .set_body_method<KVState>(&KVStateObj::EndForward);
+
+// Attention KV Cache methods
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention")
+ .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+ double attn_score_scaling_factor, NDArray q_data,
NDArray k_data,
+ NDArray v_data, NDArray o_data) {
+ kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data),
std::move(v_data),
+ NullOpt, std::move(o_data),
attn_score_scaling_factor);
+ });
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
+ .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+ double attn_score_scaling_factor, NDArray qkv_data,
NDArray o_data) {
+ kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt,
std::move(o_data),
+ attn_score_scaling_factor);
+ });
+
+// RNN State methods
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
+ .set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id,
NDArray data) {
+ state->Set(layer_id, state_id, data);
+ return state;
+ });
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get")
+ .set_body_method<RNNState>(&RNNStateObj::DebugGet);
+
+} // namespace relax_vm
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/relax_vm/kv_cache.h b/src/runtime/relax_vm/kv_state.h
similarity index 74%
rename from src/runtime/relax_vm/kv_cache.h
rename to src/runtime/relax_vm/kv_state.h
index 82e32b3af5..5f824a84b1 100644
--- a/src/runtime/relax_vm/kv_cache.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -16,30 +16,29 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
-#define TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
+#ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_
+#define TVM_RUNTIME_RELAX_VM_KV_STATE_H_
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
+#include "tvm/runtime/object.h"
+
namespace tvm {
namespace runtime {
namespace relax_vm {
-/*!
- * \brief The base class of attention KV cache for efficient
- * k/v data management and attention computation.
- */
-class AttentionKVCache : public Object {
+/*! \brief The base class of attention KV cache and rnn state. */
+class KVStateObj : public Object {
public:
- /*! \brief Reset the KV cache. */
+ /*! \brief Reset the KV State. */
virtual void Clear() = 0;
/************** Sequence Management **************/
/*!
- * \brief Add a new sequence with empty K/V data in the cache.
+ * \brief Add a new sequence with empty K/V state in the cache.
 * Check the validity of the input sequence id.
* \param seq_id The id of the new sequence to be added.
* \throws Error if the given sequence id is not valid.
@@ -47,15 +46,15 @@ class AttentionKVCache : public Object {
virtual void AddSequence(int64_t seq_id) = 0;
/*!
- * \brief Remove a sequence and its K/V data from the KV cache.
+ * \brief Remove a sequence and its K/V state from the KV cache.
* \param seq_id The sequence to remove from cache.
* \throws Error if the given sequence id is not valid.
*/
virtual void RemoveSequence(int64_t seq_id) = 0;
/*!
- * \brief Fork the K/V data of parent sequence to the child sequence.
- * After the fork, the child sequence has K/V data of the parent
+ * \brief Fork the K/V state of parent sequence to the child sequence.
+ * After the fork, the child sequence has K/V state of the parent
* sequence.
* \param parent_seq_id The parent (source) of the fork.
* \param child_seq_id The child (destination) of the fork.
@@ -73,18 +72,6 @@ class AttentionKVCache : public Object {
*/
virtual void PopN(int64_t seq_id, int32_t n) = 0;
- /************** Raw Info Query **************/
-
- /*!
- * \brief Get the number of available pages in the KV cache.
- * When the underlying KV cache implementation is not
- * paged KV cache, the function falls back to return the
- * number of remaining size (in terms of number of tokens).
- */
- virtual int32_t GetNumAvailablePages() const = 0;
-
- /************** Attention **************/
-
/*!
* \brief Mark the start of the forward function with the ids of
* the sequences and the sequence length to forward for each
@@ -109,6 +96,34 @@ class AttentionKVCache : public Object {
*/
virtual void EndForward() = 0;
+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+ static constexpr const char* _type_key = "relax.vm.KVState";
+ TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object)
+};
+
+class KVState : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(KVState, ObjectRef, KVStateObj);
+};
+
+/*!
+ * \brief The base class of attention KV cache for efficient
+ * k/v data management and attention computation.
+ */
+class AttentionKVCacheObj : public KVStateObj {
+ public:
+ /************** Raw Info Query **************/
+
+ /*!
+ * \brief Get the number of available pages in the KV cache.
+ * When the underlying KV cache implementation is not
+ * paged KV cache, the function falls back to return the
+ * number of remaining size (in terms of number of tokens).
+ */
+ virtual int32_t GetNumAvailablePages() const = 0;
+
+ /************** Attention **************/
+
/*!
* \brief Compute attention with the given Q/K/V data at the specified
* layer with regard to the previously reserved append lengths.
@@ -197,10 +212,63 @@ class AttentionKVCache : public Object {
* \param v_data The V data to set in layout elaborated above.
*/
virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data,
NDArray v_data) = 0;
+
+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+ static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
+ TVM_DECLARE_BASE_OBJECT_INFO(AttentionKVCacheObj, KVStateObj);
+};
+
+class AttentionKVCache : public KVState {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, KVState,
AttentionKVCacheObj);
+};
+
+/*!
+ * \brief The base class of RNN State for efficient
+ * State data management and attention computation.
+ */
+class RNNStateObj : public KVStateObj {
+ public:
+ /************** Interaction **************/
+ /*!
+ * \brief Get the State data for the specified sequence.
+ * \param layer_id The model layer where the state is set.
+ * \param state_id The state id within the layer.
+ * \param o_data The output data to be fetched.
+ * \return The array of State data, each element corresponds to a state.
+ * \throws Error if the given sequence id is not valid.
+ */
+ virtual void Get(int64_t layer_id, int64_t state_id, NDArray o_data) = 0;
+
+ /*!
+ * \brief Set the State data for the specified sequence.
+ * \param layer_id The model layer where the state is set.
+ * \param state_id The state id within the layer.
+ * \param data The data to be set.
+ * \throws Error if the given sequence id is not valid.
+ */
+ virtual void Set(int64_t layer_id, int64_t state_id, NDArray data) = 0;
+
+ /*!
+ * \brief Fetch the compact rnn state data of the given sequence.
+ * \param layer_id The model layer where the state is set.
+ * \param state_id The state id within the layer.
+ * \param seq_id The sequence whose state data is to be fetched.
+ */
+ virtual NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id)
= 0;
+
+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+ static constexpr const char* _type_key = "relax.vm.RNNState";
+ TVM_DECLARE_BASE_OBJECT_INFO(RNNStateObj, KVStateObj);
+};
+
+class RNNState : public KVState {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RNNState, KVState, RNNStateObj);
};
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
-#endif // TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
+#endif // TVM_RUNTIME_RELAX_VM_KV_STATE_H_
diff --git a/src/runtime/relax_vm/lm_support.cc
b/src/runtime/relax_vm/lm_support.cc
index fccff2cecd..cfb78006d7 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -59,7 +59,7 @@ namespace relax_vm {
/*!
* \brief An object representing an attention kv cache.
*/
-class AttentionKVCacheObj : public Object {
+class AttentionKVCacheLegacyObj : public Object {
public:
/*!
* \brief Underlying support data.
@@ -227,7 +227,7 @@ class AttentionKVCacheObj : public Object {
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy";
- TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheLegacyObj, Object);
};
/*! \brief reference to closure. */
@@ -239,7 +239,7 @@ class AttentionKVCacheLegacy : public ObjectRef {
*/
static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple
reserve_shape,
int init_fill_count) {
- auto n = make_object<AttentionKVCacheObj>();
+ auto n = make_object<AttentionKVCacheLegacyObj>();
n->data = NDArray::Empty(reserve_shape, init_data->dtype,
init_data->device);
n->fill_count = 0;
n->Append(init_data);
@@ -250,10 +250,11 @@ class AttentionKVCacheLegacy : public ObjectRef {
return AttentionKVCacheLegacy(n);
}
- TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef,
AttentionKVCacheObj);
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef,
+ AttentionKVCacheLegacyObj);
};
-TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
+TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj);
//-------------------------------------------------
// Register runtime functions
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc
b/src/runtime/relax_vm/paged_kv_cache.cc
index 70fa3daee7..f848ed2490 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -29,7 +29,7 @@
#include <utility>
#include <vector>
-#include "kv_cache.h"
+#include "kv_state.h"
namespace tvm {
namespace runtime {
@@ -183,7 +183,7 @@ enum class RoPEMode : int {
* After calling `EndForward`, it is required to call `BeginForward`
* before calling any `Attention`.
*/
-class PagedAttentionKVCacheObj : public AttentionKVCache {
+class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
private:
/********************* Configuration *********************/
@@ -810,7 +810,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.PagedAttentionKVCache";
- TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, Object);
+ TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, AttentionKVCacheObj);
private:
/*! \brief Get a new free page and return its id. */
@@ -1157,11 +1157,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCache
{
}
};
-class PagedAttentionKVCache : public ObjectRef {
- public:
- TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PagedAttentionKVCache, ObjectRef,
PagedAttentionKVCacheObj);
-};
-
TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
//-------------------------------------------------
@@ -1199,7 +1194,7 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
std::move(f_attention_decode_begin_forward),
std::move(f_attention_decode_end_forward),
std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_rotary_inplace),
std::move(f_debug_get_kv));
- return PagedAttentionKVCache(std::move(n));
+ return AttentionKVCache(std::move(n));
});
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
@@ -1224,38 +1219,40 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,
//
std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_rotary_inplace),
std::move(f_debug_get_kv));
- return PagedAttentionKVCache(std::move(n));
+ return AttentionKVCache(std::move(n));
});
+// Keep the following global functions for backward compatibility.
+// TODO(tvm-team): Remove these global functions in the future.
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
- .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::Clear);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Clear);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::AddSequence);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::AddSequence);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_remove_sequence")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::RemoveSequence);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::RemoveSequence);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_fork_sequence")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::ForkSequence);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::ForkSequence);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_popn")
- .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::PopN);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::PopN);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_num_available_pages")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::GetNumAvailablePages);
+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_begin_forward")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::BeginForward);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::BeginForward);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_end_forward")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::EndForward);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EndForward);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::GetQueryPositions);
+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
-
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::DebugGetKV);
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
- .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+ .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray q_data,
NDArray k_data,
NDArray v_data, NDArray o_data) {
kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data),
std::move(v_data),
NullOpt, std::move(o_data),
attn_score_scaling_factor);
});
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv")
- .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+ .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray qkv_data,
NDArray o_data) {
kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt,
std::move(o_data),
attn_score_scaling_factor);
diff --git a/src/runtime/relax_vm/rnn_state.cc
b/src/runtime/relax_vm/rnn_state.cc
new file mode 100644
index 0000000000..09873ba5f7
--- /dev/null
+++ b/src/runtime/relax_vm/rnn_state.cc
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/runtime/relax_vm/rnn_state.cc
+ * \brief Runtime RNN state object for state space models.
+ */
+
+#include <cstdint>
+#include <vector>
+
+#include "kv_state.h"
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+//-----------------------------------------------------------------------------
+// We keep the implementation private as it may be subject to future changes.
+//
+// Users can interact with it through the runtime API function calls
+//-----------------------------------------------------------------------------
+
+class RNNStateImpObj : public RNNStateObj {
+ private:
+ /********************* Data Structures *********************/
+
+ /*!
+ * \brief The sequence structure in the RNN state storage.
+ * Each sequence tracks its total length, the available rollback history,
+ * and the storage slot indices it occupies.
+ */
+ struct Sequence {
+ /*! \brief The total sequence length of the sequence. */
+ int64_t seq_length = 0;
+ /*! \brief The available history length for rolling back. */
+ int64_t available_history_num = 0;
+ /*! \brief The index of history slot in the storage. */
+ int64_t history_slot_id = 0;
+ /*! \brief The index of seq slot in the storage. */
+ int64_t seq_slot_id;
+
+ /*! \brief Constructor. */
+ explicit Sequence(int64_t seq_slot_id) : seq_slot_id(seq_slot_id) {}
+
+ static Sequence Fork(const Sequence& parent, int64_t seq_slot_id) {
+ Sequence child = parent;
+ child.seq_slot_id = seq_slot_id;
+ return child;
+ }
+ };
+
+ /********************* Configuration *********************/
+
+ /*! \brief The number of layers in the model. */
+ const int64_t num_layers_;
+ /*! \brief The max number of sequences in the storage. */
+ const int64_t reserved_num_seqs_;
+ /*! \brief The number of states per layer. */
+ const int64_t num_states_per_layer_;
+ /*! \brief The max history length for rolling back. */
+ const int64_t max_history_ = 1;
+ /*!
+ * \brief The init value for every layer in the storage.
+ * The array has `num_states_per_layer_` NDArrays
+ */
+ const Array<NDArray> init_layer_value_;
+
+ /*! \brief We fix int32 to be the index dtype of auxiliary data. */
+ const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1));
+
+ /******************* Storage Structures *******************/
+
+ /*!
+ * \brief The storages of state space models.
+ * The array has `num_layers * num_states_per_layer_` NDArrays,
+ * each of them has layout `(num_seq, max_history, state_size)`.
+ * \note As `num_states_per_layer_` may vary for different dtype and shape,
+ * we use a 2D array to store the NDArrays for each layer.
+ */
+ Array<Array<NDArray>> storages_;
+ /*! \brief The list of ids of released seq slot for reuse. */
+ std::vector<int64_t> free_slot_ids_;
+ /*! \brief The mapping from sequence ids to sequences. */
+ std::unordered_map<int64_t, Sequence> seq_map_;
+
+ /****************** Auxiliary Arrays on Host ******************/
+
+ /*! \brief The batch size of the current round of forwarding. */
+ int64_t cur_batch_size_;
+ /*! \brief The append lengths of the sequences in the current round of
forwarding. */
+ IntTuple cur_append_lengths_;
+ /*! \brief The sequence ids of the current round of forwarding. */
+ IntTuple cur_seq_ids_;
+
+ /**************** Auxiliary Arrays on Device *****************/
+
+ /*!
+ * \brief A boolean flag indicating if the auxiliary arrays are dirty.
+ * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
+ */
+ bool dirty_aux_data_device_ = false;
+ /*! \brief The device array of the sequence ids. */
+ NDArray seq_slot_ids_device_;
+ /*!
+ * \brief The view of the device array of the sequence ids.
+ * The view is used to reuse the memory but with different shape.
+ */
+ NDArray seq_slot_ids_view_;
+ /*! \brief The device array of the history slot ids. */
+ NDArray history_slot_ids_device_;
+ /*!
+ * \brief The view of the device array of the history slot ids.
+ * The view is used to reuse the memory but with different shape.
+ */
+ NDArray history_slot_ids_view_;
+
+ /******************* Interaction Functions *******************/
+
+ /*!
+ * \brief The function to get the state data from the storage.
+ * The function signature is `f_get_(state, seq_slot_ids, history_slot_ids,
out_data)`.
+ * and return the contiguous batched state data.
+ * \note Each state data per layer may have different dtype and shape, so we
use a
+ * different function for each state data.
+ */
+ Array<PackedFunc> f_gets_;
+ /*!
+ * \brief The function to set the state data to the storage.
+ * The function signature is `f_set_(state, seq_slot_ids, history_slot_ids,
data, max_history)`.
+ * where `state` is the storage NDArray, `seq_slot_ids` and
`history_slot_ids` are
+ * 1-D int32 arrays of the same length as the batch size, and `data` is the
input data.
+ * \note The `history_slot_ids` is the slot of this round, but we need to
write to the
+ * slot of the next round.
+ * \note Each state data per layer may have different dtype and shape, so we
use a
+ * different function for each state data.
+ */
+ Array<PackedFunc> f_sets_;
+
+ public:
+ /*! \brief Constructor. Take the cache configuration and initialize the
NDArrays. */
+ explicit RNNStateImpObj(int64_t num_layers, //
+ int64_t reserved_num_seqs, //
+ int64_t max_history, //
+ DLDevice device, //
+ Array<PackedFunc> f_gets, //
+ Array<PackedFunc> f_sets, //
+ Array<NDArray> init_layer_value)
+ : num_layers_(num_layers),
+ reserved_num_seqs_(reserved_num_seqs),
+ num_states_per_layer_(init_layer_value.size()),
+ max_history_(max_history),
+ init_layer_value_(init_layer_value),
+ f_gets_(std::move(f_gets)),
+ f_sets_(std::move(f_sets)) {
+ // Allocate the storage for the state space models.
+ storages_.reserve(num_layers_);
+ for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
+ Array<NDArray> layer_storages;
+ layer_storages.reserve(num_states_per_layer_);
+ for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id)
{
+ ShapeTuple state_shape = init_layer_value[state_id].Shape();
+ std::vector<ShapeTupleObj::index_type> storage_shape =
{reserved_num_seqs, max_history};
+ storage_shape.insert(storage_shape.end(), state_shape.begin(),
state_shape.end());
+ NDArray state_storage =
+ NDArray::Empty(storage_shape,
init_layer_value[state_id].DataType(), device);
+ layer_storages.push_back(state_storage);
+ }
+ storages_.push_back(layer_storages);
+ }
+
+ CHECK_GT(max_history_, 0) << "At least 1 history slot to store the current
state";
+
+ // Allocate the auxiliary arrays on device.
+ seq_slot_ids_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_,
device);
+ history_slot_ids_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_,
device);
+
+ Clear();
+ }
+
+ /*! \brief Reset the KV cache. */
+ void Clear() final {
+ seq_map_.clear();
+ ICHECK(!storages_.empty());
+ free_slot_ids_.clear();
+ for (int64_t slot_id = reserved_num_seqs_ - 1; slot_id >= 0; --slot_id) {
+ free_slot_ids_.push_back(slot_id);
+ }
+ dirty_aux_data_device_ = false;
+ }
+
+ /************** Interaction **************/
+
+ void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) {
+ CHECK_EQ(seq_ids.size(), append_lengths.size())
+ << "The seq_ids size (" << seq_ids.size() << ") and append_lengths
size ("
+ << append_lengths.size() << ") mismatch.";
+ cur_batch_size_ = seq_ids.size();
+ cur_append_lengths_ = append_lengths;
+ cur_seq_ids_ = seq_ids;
+
+ if (dirty_aux_data_device_) {
+ SyncAuxArrayToDevice();
+ }
+ }
+
+ void EndForward() final {
+ for (int64_t i = 0; i < cur_batch_size_; ++i) {
+ int64_t seq_id = cur_seq_ids_[i];
+ int64_t seq_length = cur_append_lengths_[i];
+ auto it = seq_map_.find(seq_id);
+ CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+ << "\" cannot be found in the space state
storage.";
+ it->second.seq_length += seq_length;
+ if (seq_length > 1) {
+ // We cannot rollback the prefill input
+ it->second.available_history_num = 0;
+ } else {
+ it->second.available_history_num =
+ std::min(it->second.available_history_num + 1, max_history_ - 1);
+ }
+ it->second.history_slot_id = (it->second.history_slot_id + 1) %
max_history_;
+ }
+ // TODO(Siyuan): We need to update history_slot_id_device_ (on device) as
well.
+ // There are two ways to do this:
+ // 1. Update history_slot_id_device_ on device directly through an explicit kernel
+ // 2. Update history_slot_id on host and then sync to device.
+ // We choose the second way for now for convenience. But the first way is
more efficient.
+ dirty_aux_data_device_ = true;
+ }
+
+ void Get(int64_t layer_id, int64_t state_id, NDArray o_data) final {
+ // The auxiliary data structure on device must have been synchronized.
+ CHECK(!dirty_aux_data_device_)
+ << "The auxiliary arrays are not synchronized to device. Please call "
+ "`BeginForward` to synchronize before calling `Get`.";
+ ICHECK(cur_batch_size_ == static_cast<int64_t>(cur_seq_ids_.size()))
+ << "The batch size is not consistent with the number of sequence ids.";
+ CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater
than 0.";
+ // TODO(siyuan): support zero-copy when seq_len is one
+ // Copy the state data to the return array.
+ NDArray state = storages_[layer_id][state_id];
+ f_gets_[state_id](state, seq_slot_ids_view_, history_slot_ids_view_,
o_data);
+ }
+
+ void Set(int64_t layer_id, int64_t state_id, NDArray data) final {
+ // The auxiliary data structure on device must have been synchronized.
+ CHECK(!dirty_aux_data_device_)
+ << "The auxiliary arrays are not synchronized to device. Please call "
+ "`BeginForward` to synchronize before calling `Set`.";
+ ICHECK(cur_batch_size_ == static_cast<int64_t>(cur_seq_ids_.size()))
+ << "The batch size is not consistent with the number of sequence ids.";
+ CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater
than 0.";
+
+ NDArray state = storages_[layer_id][state_id];
+ f_sets_[state_id](state, seq_slot_ids_view_, history_slot_ids_view_, data);
+ }
+
+ NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) {
+ auto it = seq_map_.find(seq_id);
+ CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+ << "\" cannot be found in the space state
storage.";
+ NDArray state = storages_[layer_id][state_id];
+ int64_t seq_slot_id = it->second.seq_slot_id;
+ int64_t history_slot_id = it->second.history_slot_id;
+
+ std::vector<int64_t> shape{state.Shape().begin() + 2, state.Shape().end()};
+ NDArray result = NDArray::Empty(shape, state->dtype, state->device);
+ DLTensor copy_src = GetStatePtrBySeqHistory(layer_id, state_id,
seq_slot_id, history_slot_id);
+ DLTensor copy_dst = *result.operator->();
+
+ NDArray::CopyFromTo(©_src, ©_dst);
+ return result;
+ }
+
+ /************** Sequence Management **************/
+
+ void AddSequence(int64_t seq_id) final {
+ CHECK(seq_map_.find(seq_id) == seq_map_.end())
+ << "The sequence \"" << seq_id << "\" is already in the space state
storage.";
+ int64_t seq_slot_id = GetFreeSlot();
+ seq_map_.insert({seq_id, Sequence(seq_slot_id)});
+
+ // Initialize the state data with the init value.
+ for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
+ for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id)
{
+ DLTensor dst =
+ GetStatePtrBySeqHistory(layer_id, state_id, seq_slot_id,
/*history_slot_id=*/0);
+ NDArray init = init_layer_value_[state_id];
+ NDArray::CopyFromTo(init.operator->(), &dst);
+ }
+ }
+
+ dirty_aux_data_device_ = true;
+ }
+
+ void RemoveSequence(int64_t seq_id) final {
+ auto it = seq_map_.find(seq_id);
+ CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+ << "\" cannot be found in the space state
storage.";
+
+ free_slot_ids_.push_back(it->second.seq_slot_id);
+ seq_map_.erase(it);
+
+ dirty_aux_data_device_ = true;
+ }
+
+ void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
+ auto parent_it = seq_map_.find(parent_seq_id);
+ CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" <<
parent_seq_id
+ << "\" cannot be found in space state
storage.";
+ CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
+ << "The child sequence \"" << child_seq_id << "\" is already in the
space state storage.";
+
+ // Create a child block with the parent block pointer.
+ int64_t child_slot_id = GetFreeSlot();
+ seq_map_.insert({child_seq_id, Sequence::Fork(parent_it->second,
child_slot_id)});
+
+ // Copy the parent state data to the child state data.
+ int64_t parent_slot_id = parent_it->second.seq_slot_id;
+ for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
+ for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id)
{
+ DLTensor copy_src = GetStatePtrBySeq(layer_id, state_id,
parent_slot_id);
+ DLTensor copy_dst = GetStatePtrBySeq(layer_id, state_id,
child_slot_id);
+ NDArray::CopyFromTo(©_src, ©_dst);
+ }
+ }
+ dirty_aux_data_device_ = true;
+ }
+
+ void PopN(int64_t seq_id, int32_t n) final {
+ auto it = seq_map_.find(seq_id);
+ CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+ << "\" cannot be found in space state.";
+ CHECK_GE(n, 0) << "The length of rolling back " << n << " cannot be
negative.";
+ CHECK_LE(n, it->second.available_history_num)
+ << "The sequence only has " << it->second.available_history_num
+ << " available history in the space state storage, while the length of
rollback is " << n
+ << " which exceeds the sequence length.";
+
+ it->second.seq_length -= n;
+ it->second.available_history_num -= n;
+ it->second.history_slot_id = (it->second.history_slot_id - n +
max_history_) % max_history_;
+ dirty_aux_data_device_ = true;
+ }
+
+ private:
+ /*! \brief Get a new free block and return its index. */
+ int32_t GetFreeSlot() {
+ CHECK(!free_slot_ids_.empty()) << "The Sequence slot is full, cannot
accept new sequence.";
+ int32_t seq_slot_id = free_slot_ids_.back();
+ free_slot_ids_.pop_back();
+ return seq_slot_id;
+ }
+
+ DLTensor GetStatePtrBySeqHistory(int64_t layer_id, int64_t state_id, int64_t
seq_slot_id,
+ int64_t history_slot_id) {
+ NDArray state = storages_[layer_id][state_id];
+ int64_t state_size = 1;
+ for (int64_t i = 2; i < state->ndim; ++i) {
+ state_size *= state->shape[i];
+ }
+ int64_t elem_offset = (seq_slot_id * max_history_ + history_slot_id) *
state_size;
+ // Create a new DLTensor with the same shape and dtype as the state.
+ DLTensor _state = *(state.operator->());
+ _state.byte_offset = elem_offset * state->dtype.bits / 8;
+ _state.ndim = state->ndim - 2;
+ _state.shape = const_cast<int64_t*>(_state.shape + 2);
+ return _state;
+ }
+
+ DLTensor GetStatePtrBySeq(int64_t layer_id, int64_t state_id, int64_t
seq_slot_id) {
+ NDArray state = storages_[layer_id][state_id];
+ int64_t state_size = 1;
+ for (int64_t i = 1; i < state->ndim; ++i) {
+ state_size *= state->shape[i];
+ }
+ int64_t elem_offset = seq_slot_id * state_size;
+ // Create a new DLTensor with the same shape and dtype as the state.
+ DLTensor _state = *(state.operator->());
+ _state.byte_offset = elem_offset * state->dtype.bits / 8;
+ _state.ndim = state->ndim - 1;
+ _state.shape = const_cast<int64_t*>(_state.shape + 1);
+ return _state;
+ }
+
+ /*!
+ * \brief Synchronize auxiliary arrays to device.
+ * \note This method resets the dirty flag to false, and needs to be
+ * invoked before running attention computation on device.
+ */
+ void SyncAuxArrayToDevice() {
+ auto fcopy_from_vec = [](NDArray array, std::vector<int32_t> vec_data) {
+ DLTensor copy_dst = *array.operator->();
+ DLTensor copy_src;
+ copy_src.data = vec_data.data();
+ copy_src.device = Device{kDLCPU, 0};
+ copy_src.ndim = 1;
+ copy_src.dtype = array->dtype;
+ copy_src.shape = array->shape;
+ copy_src.strides = nullptr;
+ copy_src.byte_offset = 0;
+ NDArray::CopyFromTo(©_src, ©_dst);
+ };
+
+ std::vector<int32_t> seq_slot_ids;
+ std::vector<int32_t> history_slot_ids;
+ seq_slot_ids.reserve(cur_batch_size_);
+ history_slot_ids.reserve(cur_batch_size_);
+ for (int64_t seq_id : cur_seq_ids_) {
+ auto it = seq_map_.find(seq_id);
+ CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+ << "\" cannot be found in the space state
storage.";
+ const Sequence& seq = it->second;
+ seq_slot_ids.push_back(seq.seq_slot_id);
+ history_slot_ids.push_back(seq.history_slot_id);
+ }
+ seq_slot_ids_view_ = seq_slot_ids_device_.CreateView({cur_batch_size_},
dtype_aux_);
+ history_slot_ids_view_ =
history_slot_ids_device_.CreateView({cur_batch_size_}, dtype_aux_);
+
+ fcopy_from_vec(seq_slot_ids_view_, seq_slot_ids);
+ fcopy_from_vec(history_slot_ids_view_, history_slot_ids);
+
+ // Reset the dirty flag to false.
+ dirty_aux_data_device_ = false;
+ }
+
+ public:
+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+ static constexpr const char* _type_key = "relax.vm.RNNStateImp";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RNNStateImpObj, RNNStateObj);
+};
+
+TVM_REGISTER_OBJECT_TYPE(RNNStateImpObj);
+
+//-------------------------------------------------
+// Register runtime functions
+//-------------------------------------------------
+
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_create")
+ .set_body_typed([](int64_t num_layers, //
+ int64_t reserved_num_seqs, //
+ int64_t max_history, //
+ Array<PackedFunc> f_gets, //
+ Array<PackedFunc> f_sets, //
+ Array<NDArray> init_layer_value) {
+ CHECK_GT(num_layers, 0) << "The number of layers should be greater than
0.";
+ CHECK_GT(reserved_num_seqs, 0)
+ << "The number of reserved sequences should be greater than 0.";
+ CHECK_GE(max_history, 0) << "The maximum history length should be
greater or equal than 0.";
+ CHECK_GT(init_layer_value.size(), 0)
+ << "The number of states per layer should be greater than 0.";
+ Device device = init_layer_value[0]->device;
+ for (const NDArray& state : init_layer_value) {
+ CHECK(state->device.device_type == device.device_type &&
+ state->device.device_id == device.device_id)
+ << "The device type of all states should be the same.";
+ }
+ CHECK_EQ(f_gets.size(), init_layer_value.size())
+ << "The number of state getters should be the same as the number of
states per layer, "
+ << "but got " << f_gets.size() << " and " << init_layer_value.size()
<< " respectively.";
+ CHECK_EQ(f_sets.size(), init_layer_value.size())
+ << "The number of state setters should be the same as the number of
states per layer, "
+ << "but got " << f_sets.size() << " and " << init_layer_value.size()
<< " respectively.";
+ ObjectPtr<RNNStateImpObj> n =
+ make_object<RNNStateImpObj>(num_layers, reserved_num_seqs,
max_history, device,
+ std::move(f_gets), std::move(f_sets),
init_layer_value);
+ return RNNState(std::move(n));
+ });
+
+} // namespace relax_vm
+} // namespace runtime
+} // namespace tvm
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py
b/tests/python/relax/test_runtime_builtin_rnn_state.py
new file mode 100644
index 0000000000..28f370bca0
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring,
+from typing import Sequence, Union
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import dlight as dl
+from tvm import tir
+from tvm.runtime import ShapeTuple
+from tvm.script import tir as T
+
+# pylint: disable=invalid-name
+
# Reference arrays: the fp16 pair seeds/overwrites state 0, the fp32 pair state 1.
np_zero, np_two = (np.full((16, 16), v, "float16") for v in (0.0, 2.0))
np_one, np_three = (np.full((32, 32), v, "float32") for v in (1.0, 3.0))

reserved_nseq = 4
max_history = 4
num_layers = 1
device = tvm.cuda()
# Note that kernels in this test file cannot support 1-dim states.
states = [((16, 16), "float16"), ((32, 32), "float32")]

# Runtime builtins, bound lazily by set_global_func().
f_clear = f_add_sequence = f_remove_sequence = f_fork_sequence = f_popn = None
f_begin_forward = f_end_forward = f_get = f_set = f_debug_get = None

# Compiled per-state TIR get/set kernels.
f_tir_gets = []
f_tir_sets = []

# pylint: enable=invalid-name
+
+
def set_global_func():
    """Bind the runtime builtins and compile the per-state TIR get/set kernels."""
    global f_clear, f_add_sequence, f_remove_sequence, f_fork_sequence, f_popn
    global f_begin_forward, f_end_forward, f_get, f_set, f_debug_get
    global f_tir_gets, f_tir_sets

    fget = tvm.get_global_func
    f_clear = fget("vm.builtin.kv_state_clear")
    f_add_sequence = fget("vm.builtin.kv_state_add_sequence")
    f_remove_sequence = fget("vm.builtin.kv_state_remove_sequence")
    f_fork_sequence = fget("vm.builtin.kv_state_fork_sequence")
    f_popn = fget("vm.builtin.kv_state_popn")
    f_begin_forward = fget("vm.builtin.kv_state_begin_forward")
    f_end_forward = fget("vm.builtin.kv_state_end_forward")
    f_get = fget("vm.builtin.rnn_state_get")
    f_set = fget("vm.builtin.rnn_state_set")
    f_debug_get = fget("vm.builtin.rnn_state_debug_get")

    target = tvm.target.Target("cuda")

    def _build(tir_func):
        # Schedule with dlight's generic GPU fallback, then compile for CUDA.
        mod = tvm.IRModule({"main": tir_func})
        with target:
            mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)  # pylint: disable=not-callable
        return tvm.build(mod["main"], target=target).entry_func

    f_tir_gets = [_build(rnn_state_get(shape, dtype)) for shape, dtype in states]
    f_tir_sets = [_build(rnn_state_set(shape, dtype)) for shape, dtype in states]
+
+
def create_rnn_state():
    """Instantiate an RNNState with one layer and the two configured states."""
    make_state = tvm.get_global_func("vm.builtin.rnn_state_create")
    init_values = [tvm.nd.array(arr, device=device) for arr in (np_zero, np_one)]
    return make_state(num_layers, reserved_nseq, max_history, f_tir_gets, f_tir_sets, init_values)
+
+
@pytest.fixture
def rnn_state():
    """Fresh RNNState per test; also (re)binds the global runtime functions.

    Fix: the decorator had been mangled by the mailing-list obfuscator into
    ``[email protected]`` — restored to ``@pytest.fixture``.
    """
    set_global_func()
    return create_rnn_state()
+
+
def verify_state(state, seq_ids, expected_values):
    """Compare layer-0 states of each listed sequence with the expected arrays."""
    layer_id = 0
    for seq_id in seq_ids:
        for state_id, expected in enumerate(expected_values[seq_id]):
            actual = f_debug_get(state, layer_id, state_id, seq_id)
            tvm.testing.assert_allclose(actual.numpy(), expected)
+
+
# Fix: decorator restored from email-mangled "[email protected]_cuda".
@tvm.testing.requires_cuda
def test_rnn_state_get(rnn_state):  # pylint: disable=redefined-outer-name
    """A freshly added sequence reads back the configured init values."""
    state = rnn_state
    f_clear(state)
    f_add_sequence(state, 0)
    f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
    tvm_nd_0 = tvm.nd.array(np.empty((1, 16, 16), "float16"), device=device)
    tvm_nd_1 = tvm.nd.array(np.empty((1, 32, 32), "float32"), device=device)
    f_get(state, 0, 0, tvm_nd_0)
    f_get(state, 0, 1, tvm_nd_1)
    f_end_forward(state)
    tvm.testing.assert_allclose(tvm_nd_0.numpy(), np.zeros((1, 16, 16), "float16"))
    tvm.testing.assert_allclose(tvm_nd_1.numpy(), np.ones((1, 32, 32), "float32"))
+
+
# Fix: decorator restored from email-mangled "[email protected]_cuda".
@tvm.testing.requires_cuda
def test_rnn_state_set(rnn_state):  # pylint: disable=redefined-outer-name
    """Setting states for a sub-batch (seqs 0 and 2) leaves seq 1 untouched."""
    state = rnn_state
    f_clear(state)
    for seq_id in range(3):
        f_add_sequence(state, seq_id)
    f_begin_forward(state, ShapeTuple([0, 2]), ShapeTuple([1, 1]))

    f_set(state, 0, 0, tvm.nd.array(np.full((2, 16, 16), 2.0, "float16"), device=device))
    f_set(state, 0, 1, tvm.nd.array(np.full((2, 32, 32), 3.0, "float32"), device=device))
    f_end_forward(state)

    expected_values = [[np_two, np_three], [np_zero, np_one], [np_two, np_three]]
    verify_state(state, [0, 1, 2], expected_values)
+
+
# Fix: decorator restored from email-mangled "[email protected]_cuda".
@tvm.testing.requires_cuda
def test_rnn_state_popn(rnn_state):  # pylint: disable=redefined-outer-name
    """Popping one step restores the previous history entry; over-popping raises."""
    state = rnn_state
    f_clear(state)

    f_add_sequence(state, 0)
    f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
    f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device))
    f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), device=device))
    f_end_forward(state)

    verify_state(state, [0], [[np_two, np_three]])
    f_popn(state, 0, 1)
    verify_state(state, [0], [[np_zero, np_one]])
    with pytest.raises(tvm.error.TVMError):
        f_popn(state, 0, 1)  # no available history to pop
+
+
# Fix: decorator restored from email-mangled "[email protected]_cuda".
@tvm.testing.requires_cuda
def test_rnn_state_fork_sequence(rnn_state):  # pylint: disable=redefined-outer-name
    """A forked sequence inherits the parent's state and its own history."""
    state = rnn_state
    f_clear(state)

    f_add_sequence(state, 0)
    f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
    f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device))
    f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), device=device))
    f_end_forward(state)
    f_fork_sequence(state, 0, 1)
    verify_state(state, [0, 1], [[np_two, np_three], [np_two, np_three]])
    # Verify popn for the forked sequence: only the child rolls back.
    f_popn(state, 1, 1)
    verify_state(state, [0, 1], [[np_two, np_three], [np_zero, np_one]])
+
+
def rnn_state_get(
    shape: Sequence[int],
    dtype: str,
):
    """Build a TIR kernel that gathers each batch row's state from storage."""

    # fmt: off
    @T.prim_func
    def _rnn_state_get(
        var_storage: T.handle,
        var_seq_slot_ids: T.handle,
        var_history_slot_ids: T.handle,
        var_output: T.handle,
    ):
        batch_size = T.int32(is_size_var=True)

        storage = T.match_buffer(var_storage, (reserved_nseq, max_history, *shape), dtype)
        seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32")
        history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), "int32")
        output = T.match_buffer(var_output, (batch_size, *shape), dtype)

        for b in range(batch_size):
            for spatial in T.grid(*shape):
                with T.block("copy"):
                    vb, *vspatial = T.axis.remap("S" * (len(shape) + 1), [b, *spatial])
                    slot: T.int32 = seq_slot_ids[vb]
                    hist: T.int32 = history_slot_ids[vb]
                    # Equivalent to `output[vb, *vspatial] = storage[slot, hist, *vspatial]`,
                    # but the unpacking operator in subscripts needs Python 3.11+.
                    T.buffer_store(
                        output, T.BufferLoad(storage, [slot, hist, *vspatial]), [vb, *vspatial]
                    )
    # fmt: on
    return _rnn_state_get
+
+
def rnn_state_set(
    shape: Sequence[Union[int, tir.Var]],
    dtype: str,
):
    """Build a TIR kernel that writes each batch row's state into the next
    history ring slot of its sequence."""

    # fmt: off
    @T.prim_func
    def _rnn_state_set(
        var_storage: T.handle,
        var_seq_slot_ids: T.handle,
        var_history_slot_ids: T.handle,
        var_data: T.handle,
    ):
        batch_size = T.int32(is_size_var=True)

        storage = T.match_buffer(var_storage, (reserved_nseq, max_history, *shape), dtype)
        seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32")
        history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), "int32")
        data = T.match_buffer(var_data, (batch_size, *shape), dtype)

        for b in range(batch_size):
            for spatial in T.grid(*shape):
                with T.block("copy"):
                    vb, *vspatial = T.axis.remap("S" * (len(shape) + 1), [b, *spatial])
                    slot: T.int32 = seq_slot_ids[vb]
                    # Write to the slot after the current one in the history ring.
                    hist: T.int32 = (history_slot_ids[vb] + 1) % T.cast(
                        max_history, "int32"
                    )
                    # Equivalent to `storage[slot, hist, *vspatial] = data[vb, *vspatial]`,
                    # but the unpacking operator in subscripts needs Python 3.11+.
                    T.buffer_store(
                        storage, T.BufferLoad(data, [vb, *vspatial]), [slot, hist, *vspatial]
                    )

    # fmt: on

    return _rnn_state_set
+
+
if __name__ == "__main__":
    # Run all tests directly (without pytest), sharing one state instance.
    set_global_func()
    rnn_state = create_rnn_state()
    for run in (
        test_rnn_state_get,
        test_rnn_state_set,
        test_rnn_state_popn,
        test_rnn_state_fork_sequence,
    ):
        run(rnn_state)