This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ef478b486 [Relax][Runtime] RNNState for Space State Models (#16568)
3ef478b486 is described below

commit 3ef478b4864c110c33fde937b9f6b8e604e957d3
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Feb 21 21:32:23 2024 +0800

    [Relax][Runtime] RNNState for Space State Models (#16568)
    
    * [Relax][Runtime] RNNState for Space State Models
    
    This commit adds the RNNState class to the Relax VM, similar to the
    PagedKVCache, for space state models like RWKV and mamba
    
    * refactor
---
 src/runtime/relax_vm/kv_state.cc                   |  80 ++++
 src/runtime/relax_vm/{kv_cache.h => kv_state.h}    | 118 +++--
 src/runtime/relax_vm/lm_support.cc                 |  11 +-
 src/runtime/relax_vm/paged_kv_cache.cc             |  41 +-
 src/runtime/relax_vm/rnn_state.cc                  | 487 +++++++++++++++++++++
 .../python/relax/test_runtime_builtin_rnn_state.py | 262 +++++++++++
 6 files changed, 947 insertions(+), 52 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
new file mode 100644
index 0000000000..7c86e96ec6
--- /dev/null
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "kv_state.h"
+
+#include <utility>
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+// Register Object Type
+TVM_REGISTER_OBJECT_TYPE(KVStateObj);
+TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
+TVM_REGISTER_OBJECT_TYPE(RNNStateObj);
+
+// KV State base methods
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method<KVState>(&KVStateObj::Clear);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence")
+    .set_body_method<KVState>(&KVStateObj::AddSequence);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence")
+    .set_body_method<KVState>(&KVStateObj::RemoveSequence);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
+    .set_body_method<KVState>(&KVStateObj::ForkSequence);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
+    .set_body_method<KVState>(&KVStateObj::BeginForward);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
+    .set_body_method<KVState>(&KVStateObj::EndForward);
+
+// Attention KV Cache methods
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, 
NDArray k_data,
+                       NDArray v_data, NDArray o_data) {
+      kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), 
std::move(v_data),
+                          NullOpt, std::move(o_data), 
attn_score_scaling_factor);
+    });
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray qkv_data, 
NDArray o_data) {
+      kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, 
std::move(o_data),
+                                      attn_score_scaling_factor);
+    });
+
+// RNN State methods
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
+    .set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id, 
NDArray data) {
+      state->Set(layer_id, state_id, data);
+      return state;
+    });
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get")
+    .set_body_method<RNNState>(&RNNStateObj::DebugGet);
+
+}  // namespace relax_vm
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/relax_vm/kv_cache.h b/src/runtime/relax_vm/kv_state.h
similarity index 74%
rename from src/runtime/relax_vm/kv_cache.h
rename to src/runtime/relax_vm/kv_state.h
index 82e32b3af5..5f824a84b1 100644
--- a/src/runtime/relax_vm/kv_cache.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -16,30 +16,29 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#ifndef TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
-#define TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
+#ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_
+#define TVM_RUNTIME_RELAX_VM_KV_STATE_H_
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/registry.h>
 
+#include "tvm/runtime/object.h"
+
 namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
-/*!
- * \brief The base class of attention KV cache for efficient
- * k/v data management and attention computation.
- */
-class AttentionKVCache : public Object {
+/*! \brief The base class of attention KV cache and rnn state. */
+class KVStateObj : public Object {
  public:
-  /*! \brief Reset the KV cache. */
+  /*! \brief Reset the KV State. */
   virtual void Clear() = 0;
 
   /************** Sequence Management **************/
 
   /*!
-   * \brief Add a new sequence with empty K/V data in the cache.
+   * \brief Add a new sequence with empty K/V state in the cache.
    * Check if the validity of the input sequence id.
    * \param seq_id The id of the new sequence to be added.
    * \throws Error if the given sequence id is not valid.
@@ -47,15 +46,15 @@ class AttentionKVCache : public Object {
   virtual void AddSequence(int64_t seq_id) = 0;
 
   /*!
-   * \brief Remove a sequence and its K/V data from the KV cache.
+   * \brief Remove a sequence and its K/V state from the KV cache.
    * \param seq_id The sequence to remove from cache.
    * \throws Error if the given sequence id is not valid.
    */
   virtual void RemoveSequence(int64_t seq_id) = 0;
 
   /*!
-   * \brief Fork the K/V data of parent sequence to the child sequence.
-   * After the fork, the child sequence has K/V data of the parent
+   * \brief Fork the K/V state of parent sequence to the child sequence.
+   * After the fork, the child sequence has K/V state of the parent
    * sequence.
    * \param parent_seq_id The parent (source) of the fork.
    * \param child_seq_id The child (destination) of the fork.
@@ -73,18 +72,6 @@ class AttentionKVCache : public Object {
    */
   virtual void PopN(int64_t seq_id, int32_t n) = 0;
 
-  /************** Raw Info Query **************/
-
-  /*!
-   * \brief Get the number of available pages in the KV cache.
-   * When the underlying KV cache implementation is not
-   * paged KV cache, the function falls back to return the
-   * number of remaining size (in terms of number of tokens).
-   */
-  virtual int32_t GetNumAvailablePages() const = 0;
-
-  /************** Attention **************/
-
   /*!
    * \brief Mark the start of the forward function with the ids of
    * the sequences and the sequence length to forward for each
@@ -109,6 +96,34 @@ class AttentionKVCache : public Object {
    */
   virtual void EndForward() = 0;
 
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "relax.vm.KVState";
+  TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object)
+};
+
+class KVState : public ObjectRef {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(KVState, ObjectRef, KVStateObj);
+};
+
+/*!
+ * \brief The base class of attention KV cache for efficient
+ * k/v data management and attention computation.
+ */
+class AttentionKVCacheObj : public KVStateObj {
+ public:
+  /************** Raw Info Query **************/
+
+  /*!
+   * \brief Get the number of available pages in the KV cache.
+   * When the underlying KV cache implementation is not
+   * paged KV cache, the function falls back to return the
+   * number of remaining size (in terms of number of tokens).
+   */
+  virtual int32_t GetNumAvailablePages() const = 0;
+
+  /************** Attention **************/
+
   /*!
    * \brief Compute attention with the given Q/K/V data at the specified
    * layer with regard to the previously reserved append lengths.
@@ -197,10 +212,63 @@ class AttentionKVCache : public Object {
    * \param v_data The V data to set in layout elaborated above.
    */
   virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, 
NDArray v_data) = 0;
+
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
+  TVM_DECLARE_BASE_OBJECT_INFO(AttentionKVCacheObj, KVStateObj);
+};
+
+class AttentionKVCache : public KVState {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, KVState, 
AttentionKVCacheObj);
+};
+
+/*!
+ * \brief The base class of RNN State for efficient
+ * State data management and access.
+ */
+class RNNStateObj : public KVStateObj {
+ public:
+  /************** Interaction **************/
+  /*!
+   * \brief Get the State data of the specified state in the specified layer.
+   * \param layer_id The model layer where the state is stored.
+   * \param state_id The state id within the layer.
+   * \param o_data The output array that the fetched State data is written to.
+   * \note Nothing is returned; the fetched State data is delivered through `o_data`.
+   * \throws Error if the given sequence id is not valid.
+   */
+  virtual void Get(int64_t layer_id, int64_t state_id, NDArray o_data) = 0;
+
+  /*!
+   * \brief Set the State data for the specified sequence.
+   * \param layer_id The model layer where the state is set.
+   * \param state_id The state id within the layer.
+   * \param data The data to be set.
+   * \throws Error if the given sequence id is not valid.
+   */
+  virtual void Set(int64_t layer_id, int64_t state_id, NDArray data) = 0;
+
+  /*!
+   * \brief Fetch the compact rnn state data of the given sequence.
+   * \param layer_id The model layer where the state is set.
+   * \param state_id The state id within the layer.
+   * \param seq_id The sequence whose state data is to be fetched.
+   */
+  virtual NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) 
= 0;
+
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "relax.vm.RNNState";
+  TVM_DECLARE_BASE_OBJECT_INFO(RNNStateObj, KVStateObj);
+};
+
+class RNNState : public KVState {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RNNState, KVState, RNNStateObj);
 };
 
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
 
-#endif  // TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
+#endif  // TVM_RUNTIME_RELAX_VM_KV_STATE_H_
diff --git a/src/runtime/relax_vm/lm_support.cc 
b/src/runtime/relax_vm/lm_support.cc
index fccff2cecd..cfb78006d7 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -59,7 +59,7 @@ namespace relax_vm {
 /*!
  * \brief An object representing an attention kv cache.
  */
-class AttentionKVCacheObj : public Object {
+class AttentionKVCacheLegacyObj : public Object {
  public:
   /*!
    * \brief Underlying support data.
@@ -227,7 +227,7 @@ class AttentionKVCacheObj : public Object {
 
   static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
   static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy";
-  TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheLegacyObj, Object);
 };
 
 /*! \brief reference to closure. */
@@ -239,7 +239,7 @@ class AttentionKVCacheLegacy : public ObjectRef {
    */
   static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple 
reserve_shape,
                                        int init_fill_count) {
-    auto n = make_object<AttentionKVCacheObj>();
+    auto n = make_object<AttentionKVCacheLegacyObj>();
     n->data = NDArray::Empty(reserve_shape, init_data->dtype, 
init_data->device);
     n->fill_count = 0;
     n->Append(init_data);
@@ -250,10 +250,11 @@ class AttentionKVCacheLegacy : public ObjectRef {
     return AttentionKVCacheLegacy(n);
   }
 
-  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef, 
AttentionKVCacheObj);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef,
+                                        AttentionKVCacheLegacyObj);
 };
 
-TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
+TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj);
 
 //-------------------------------------------------
 //  Register runtime functions
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 70fa3daee7..f848ed2490 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -29,7 +29,7 @@
 #include <utility>
 #include <vector>
 
-#include "kv_cache.h"
+#include "kv_state.h"
 
 namespace tvm {
 namespace runtime {
@@ -183,7 +183,7 @@ enum class RoPEMode : int {
  *     After calling `EndForward`, it is required to call `BeginForward`
  *     before calling any `Attention`.
  */
-class PagedAttentionKVCacheObj : public AttentionKVCache {
+class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
  private:
   /********************* Configuration *********************/
 
@@ -810,7 +810,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
   static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
   static constexpr const char* _type_key = "relax.vm.PagedAttentionKVCache";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, AttentionKVCacheObj);
 
  private:
   /*! \brief Get a new free page and return its id. */
@@ -1157,11 +1157,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCache 
{
   }
 };
 
-class PagedAttentionKVCache : public ObjectRef {
- public:
-  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PagedAttentionKVCache, ObjectRef, 
PagedAttentionKVCacheObj);
-};
-
 TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
 //-------------------------------------------------
@@ -1199,7 +1194,7 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
           std::move(f_attention_decode_begin_forward), 
std::move(f_attention_decode_end_forward),
           std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_rotary_inplace),
           std::move(f_debug_get_kv));
-      return PagedAttentionKVCache(std::move(n));
+      return AttentionKVCache(std::move(n));
     });
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
@@ -1224,38 +1219,40 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
           NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,                
  //
           std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_rotary_inplace),
           std::move(f_debug_get_kv));
-      return PagedAttentionKVCache(std::move(n));
+      return AttentionKVCache(std::move(n));
     });
 
+// Keep the following global functions for backward compatibility.
+// TODO(tvm-team): Remove these global functions in the future.
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
-    .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::Clear);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Clear);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::AddSequence);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::AddSequence);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_remove_sequence")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::RemoveSequence);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::RemoveSequence);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_fork_sequence")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::ForkSequence);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::ForkSequence);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_popn")
-    .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::PopN);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::PopN);
 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_num_available_pages")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::GetNumAvailablePages);
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_begin_forward")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::BeginForward);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::BeginForward);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_end_forward")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::EndForward);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EndForward);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::GetQueryPositions);
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
-    
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::DebugGetKV);
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
-    .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
                        double attn_score_scaling_factor, NDArray q_data, 
NDArray k_data,
                        NDArray v_data, NDArray o_data) {
       kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), 
std::move(v_data),
                           NullOpt, std::move(o_data), 
attn_score_scaling_factor);
     });
 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv")
-    .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
                        double attn_score_scaling_factor, NDArray qkv_data, 
NDArray o_data) {
       kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, 
std::move(o_data),
                                       attn_score_scaling_factor);
diff --git a/src/runtime/relax_vm/rnn_state.cc 
b/src/runtime/relax_vm/rnn_state.cc
new file mode 100644
index 0000000000..09873ba5f7
--- /dev/null
+++ b/src/runtime/relax_vm/rnn_state.cc
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/runtime/relax_vm/rnn_state.cc
+ * \brief Runtime RNN state object for state space models.
+ */
+
+#include <cstdint>
+#include <vector>
+
+#include "kv_state.h"
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+//-----------------------------------------------------------------------------
+// We keep the implementation private as it may be subject to future changes.
+//
+// Users can interact with it through the runtime API function calls
+//-----------------------------------------------------------------------------
+
+class RNNStateImpObj : public RNNStateObj {
+ private:
+  /********************* Data Structures *********************/
+
+  /*!
+   * \brief The per-sequence bookkeeping record of the RNN state.
+   * It tracks the sequence length, the rollback history and the storage slots of a sequence.
+   */
+  struct Sequence {
+    /*! \brief The total sequence length of the sequence. */
+    int64_t seq_length = 0;
+    /*! \brief The available history length for rolling back. */
+    int64_t available_history_num = 0;
+    /*! \brief The index of history slot in the storage. */
+    int64_t history_slot_id = 0;
+    /*! \brief The index of seq slot in the storage. */
+    int64_t seq_slot_id;
+
+    /*! \brief Constructor. */
+    explicit Sequence(int64_t seq_slot_id) : seq_slot_id(seq_slot_id) {}
+
+    static Sequence Fork(const Sequence& parent, int64_t seq_slot_id) {
+      Sequence child = parent;
+      child.seq_slot_id = seq_slot_id;
+      return child;
+    }
+  };
+
+  /********************* Configuration *********************/
+
+  /*! \brief The number of layers in the model. */
+  const int64_t num_layers_;
+  /*! \brief The max number of sequences in the storage. */
+  const int64_t reserved_num_seqs_;
+  /*! \brief The number of states per layer. */
+  const int64_t num_states_per_layer_;
+  /*! \brief The max history length for rolling back. */
+  const int64_t max_history_ = 1;
+  /*!
+   * \brief The init value for ALL layers in the storage.
+   * The array has `num_states_per_layer_` NDArrays
+   */
+  const Array<NDArray> init_layer_value_;
+
+  /*! \brief We fix int32 to be the index dtype of auxiliary data. */
+  const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1));
+
+  /******************* Storage Structures *******************/
+
+  /*!
+   * \brief The storages of state space models.
+   * The array has `num_layers * num_states_per_layer_` NDArrays,
+   * each of them has layout `(reserved_num_seqs, max_history, *state_shape)`.
+   * \note As the states within a layer may differ in dtype and shape,
+   * we use a 2D array to store the NDArrays for each layer.
+   */
+  Array<Array<NDArray>> storages_;
+  /*! \brief The list of ids of released seq slot for reuse. */
+  std::vector<int64_t> free_slot_ids_;
+  /*! \brief The mapping from sequence ids to sequences. */
+  std::unordered_map<int64_t, Sequence> seq_map_;
+
+  /****************** Auxiliary Arrays on Host ******************/
+
+  /*! \brief The batch size of the current round of forwarding. */
+  int64_t cur_batch_size_;
+  /*! \brief The append lengths of the sequences in the current round of 
forwarding. */
+  IntTuple cur_append_lengths_;
+  /*! \brief The sequence ids of the current round of forwarding. */
+  IntTuple cur_seq_ids_;
+
+  /**************** Auxiliary Arrays on Device *****************/
+
+  /*!
+   * \brief A boolean flag indicating if the auxiliary arrays are dirty.
+   * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
+   */
+  bool dirty_aux_data_device_ = false;
+  /*! \brief The device array of the sequence ids. */
+  NDArray seq_slot_ids_device_;
+  /*!
+   * \brief The view of the device array of the sequence ids.
+   * The view is used to reuse the memory but with different shape.
+   */
+  NDArray seq_slot_ids_view_;
+  /*! \brief The device array of the history slot ids. */
+  NDArray history_slot_ids_device_;
+  /*!
+   * \brief The view of the device array of the history slot ids.
+   * The view is used to reuse the memory but with different shape.
+   */
+  NDArray history_slot_ids_view_;
+
+  /******************* Interaction Functions *******************/
+
+  /*!
+   * \brief The functions to get the state data from the storage.
+   * The function signature is `f_get_(state, seq_slot_ids, history_slot_ids, out_data)`,
+   * and each function returns the contiguous batched state data.
+   * \note Each state data per layer may have different dtype and shape, so we use a
+   * different function for each state data.
+   */
+  Array<PackedFunc> f_gets_;
+  /*!
+   * \brief The functions to set the state data to the storage.
+   * The function signature is `f_set_(state, seq_slot_ids, history_slot_ids, data, max_history)`,
+   * where `state` is the storage NDArray, `seq_slot_ids` and `history_slot_ids` are
+   * 1-D int32 arrays of the same length as the batch size, and `data` is the input data.
+   * \note The `history_slot_ids` is the slot of this round, but we need to write to the
+   * slot of the next round.
+   * \note Each state data per layer may have different dtype and shape, so we use a
+   * different function for each state data.
+   */
+  Array<PackedFunc> f_sets_;
+
+ public:
+  /*!
+   * \brief Constructor. Take the state configuration and allocate the NDArrays.
+   * \param num_layers The number of layers; one set of states is stored per layer.
+   * \param reserved_num_seqs The number of sequence slots to allocate.
+   * \param max_history The number of history slots kept per sequence. Must be positive.
+   * \param device The device on which the storages and auxiliary arrays are allocated.
+   * \param f_gets The per-state getter functions, indexed by state id.
+   * \param f_sets The per-state setter functions, indexed by state id.
+   * \param init_layer_value The initial value of each state of a layer. The storage of
+   * state `i` has shape `[reserved_num_seqs, max_history, *init_layer_value[i].Shape()]`
+   * and the dtype of `init_layer_value[i]`.
+   */
+  explicit RNNStateImpObj(int64_t num_layers,         //
+                          int64_t reserved_num_seqs,  //
+                          int64_t max_history,        //
+                          DLDevice device,            //
+                          Array<PackedFunc> f_gets,   //
+                          Array<PackedFunc> f_sets,   //
+                          Array<NDArray> init_layer_value)
+      : num_layers_(num_layers),
+        reserved_num_seqs_(reserved_num_seqs),
+        num_states_per_layer_(init_layer_value.size()),
+        max_history_(max_history),
+        init_layer_value_(init_layer_value),
+        f_gets_(std::move(f_gets)),
+        f_sets_(std::move(f_sets)) {
+    // Validate the configuration before touching the device, so that an invalid
+    // `max_history` is rejected before it is used as a storage dimension.
+    CHECK_GT(max_history_, 0) << "At least 1 history slot to store the current state";
+
+    // Allocate the storage for the space state models.
+    storages_.reserve(num_layers_);
+    for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
+      Array<NDArray> layer_storages;
+      layer_storages.reserve(num_states_per_layer_);
+      for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id) {
+        // Storage layout: [reserved_num_seqs, max_history, *state_shape].
+        ShapeTuple state_shape = init_layer_value[state_id].Shape();
+        std::vector<ShapeTupleObj::index_type> storage_shape = {reserved_num_seqs, max_history};
+        storage_shape.insert(storage_shape.end(), state_shape.begin(), state_shape.end());
+        NDArray state_storage =
+            NDArray::Empty(storage_shape, init_layer_value[state_id].DataType(), device);
+        layer_storages.push_back(state_storage);
+      }
+      storages_.push_back(layer_storages);
+    }
+
+    // Allocate the auxiliary arrays on device.
+    seq_slot_ids_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
+    history_slot_ids_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
+
+    Clear();
+  }
+
+  /*! \brief Reset the RNN state: forget all sequences and mark every slot as free. */
+  void Clear() final {
+    seq_map_.clear();
+    ICHECK(!storages_.empty());
+    // Refill the free list in descending order, ending with slot 0 at the back.
+    free_slot_ids_.clear();
+    int64_t slot_id = reserved_num_seqs_;
+    while (slot_id > 0) {
+      free_slot_ids_.push_back(--slot_id);
+    }
+    dirty_aux_data_device_ = false;
+  }
+
+  /************** Interaction **************/
+
+  /*!
+   * \brief Mark the start of a forwarding round for the given sequences.
+   * \param seq_ids The ids of the sequences taking part in this round.
+   * \param append_lengths The append length of each sequence in this round.
+   * Must have the same number of entries as `seq_ids`.
+   */
+  void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) {
+    CHECK_EQ(seq_ids.size(), append_lengths.size())
+        << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size ("
+        << append_lengths.size() << ") mismatch.";
+    cur_batch_size_ = seq_ids.size();
+    cur_append_lengths_ = append_lengths;
+    cur_seq_ids_ = seq_ids;
+
+    // The slot-id arrays on device are derived from `cur_seq_ids_`, which has just
+    // been replaced. Always rebuild them: relying on the dirty flag alone would keep
+    // the views of a previous batch when BeginForward is called again without any
+    // state change in between.
+    SyncAuxArrayToDevice();
+  }
+
+  /*!
+   * \brief Mark the end of a forwarding round and advance the host-side bookkeeping
+   * (sequence length, available history count and history slot) of every sequence
+   * in the current batch.
+   */
+  void EndForward() final {
+    for (int64_t i = 0; i < cur_batch_size_; ++i) {
+      const int64_t seq_id = cur_seq_ids_[i];
+      const int64_t append_length = cur_append_lengths_[i];
+      auto it = seq_map_.find(seq_id);
+      CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+                                  << "\" cannot be found in the space state storage.";
+      Sequence& seq = it->second;
+      seq.seq_length += append_length;
+      // A multi-token (prefill) step cannot be rolled back, so it resets the
+      // available history; a single-token step adds one entry, capped at max_history_ - 1.
+      seq.available_history_num =
+          append_length > 1 ? 0 : std::min(seq.available_history_num + 1, max_history_ - 1);
+      seq.history_slot_id = (seq.history_slot_id + 1) % max_history_;
+    }
+    // TODO(Siyuan): We need to update history_slot_ids_device_ (on device) as well.
+    // There are two ways to do this:
+    // 1. Update history_slot_ids_device_ on device directly through an explicit kernel
+    // 2. Update history_slot_id on host and then sync to device.
+    // We choose the second way for now for convenience. But the first way is more efficient.
+    dirty_aux_data_device_ = true;
+  }
+
+  /*!
+   * \brief Read one state of one layer for all sequences of the current batch.
+   * \param layer_id The layer whose state is read.
+   * \param state_id The index of the state within the layer.
+   * \param o_data The output array handed to the getter function of this state.
+   */
+  void Get(int64_t layer_id, int64_t state_id, NDArray o_data) final {
+    // The slot ids must already be on device before the getter is run.
+    CHECK(!dirty_aux_data_device_)
+        << "The auxiliary arrays are not synchronized to device. Please call "
+           "`BeginForward` to synchronize before calling `Get`.";
+    ICHECK(cur_batch_size_ == static_cast<int64_t>(cur_seq_ids_.size()))
+        << "The batch size is not consistent with the number of sequence ids.";
+    CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0.";
+    // TODO(siyuan): support zero-copy when seq_len is one
+    // Delegate the copy into `o_data` to the per-state getter.
+    NDArray storage = storages_[layer_id][state_id];
+    PackedFunc getter = f_gets_[state_id];
+    getter(storage, seq_slot_ids_view_, history_slot_ids_view_, o_data);
+  }
+
+  /*!
+   * \brief Write one state of one layer for all sequences of the current batch.
+   * \param layer_id The layer whose state is written.
+   * \param state_id The index of the state within the layer.
+   * \param data The input data handed to the setter function of this state.
+   */
+  void Set(int64_t layer_id, int64_t state_id, NDArray data) final {
+    // The slot ids must already be on device before the setter is run.
+    CHECK(!dirty_aux_data_device_)
+        << "The auxiliary arrays are not synchronized to device. Please call "
+           "`BeginForward` to synchronize before calling `Set`.";
+    ICHECK(cur_batch_size_ == static_cast<int64_t>(cur_seq_ids_.size()))
+        << "The batch size is not consistent with the number of sequence ids.";
+    CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0.";
+
+    // Delegate the write into the storage to the per-state setter.
+    NDArray storage = storages_[layer_id][state_id];
+    PackedFunc setter = f_sets_[state_id];
+    setter(storage, seq_slot_ids_view_, history_slot_ids_view_, data);
+  }
+
+  /*!
+   * \brief Debug helper: return a copy of the state of a single sequence at its
+   * current history slot.
+   * \param layer_id The layer whose state is read.
+   * \param state_id The index of the state within the layer.
+   * \param seq_id The id of the sequence to read.
+   * \return A newly allocated NDArray with the per-state shape (the storage shape
+   * without its two leading dimensions), on the same device as the storage.
+   */
+  NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) {
+    auto it = seq_map_.find(seq_id);
+    CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+                                << "\" cannot be found in the space state storage.";
+    NDArray storage = storages_[layer_id][state_id];
+    const Sequence& seq = it->second;
+
+    // Drop the two leading storage dimensions to obtain the per-state shape.
+    ShapeTuple storage_shape = storage.Shape();
+    std::vector<int64_t> state_shape(storage_shape.begin() + 2, storage_shape.end());
+    NDArray result = NDArray::Empty(state_shape, storage->dtype, storage->device);
+
+    DLTensor copy_src =
+        GetStatePtrBySeqHistory(layer_id, state_id, seq.seq_slot_id, seq.history_slot_id);
+    DLTensor copy_dst = *result.operator->();
+    NDArray::CopyFromTo(&copy_src, &copy_dst);
+    return result;
+  }
+
+  /************** Sequence Management **************/
+
+  /*!
+   * \brief Register a new sequence on a free slot and initialize its states.
+   * \param seq_id The id of the new sequence; it must not exist yet.
+   */
+  void AddSequence(int64_t seq_id) final {
+    CHECK(seq_map_.find(seq_id) == seq_map_.end())
+        << "The sequence \"" << seq_id << "\" is already in the space state storage.";
+    const int64_t slot = GetFreeSlot();
+    seq_map_.insert({seq_id, Sequence(slot)});
+
+    // Seed history slot 0 of every state of every layer with its initial value.
+    for (int64_t layer = 0; layer < num_layers_; ++layer) {
+      for (int64_t s = 0; s < num_states_per_layer_; ++s) {
+        DLTensor dst = GetStatePtrBySeqHistory(layer, s, slot, /*history_slot_id=*/0);
+        NDArray init_value = init_layer_value_[s];
+        NDArray::CopyFromTo(init_value.operator->(), &dst);
+      }
+    }
+
+    // The host-side bookkeeping changed; the device copy must be refreshed.
+    dirty_aux_data_device_ = true;
+  }
+
+  /*!
+   * \brief Remove a sequence and return its slot to the free list.
+   * \param seq_id The id of the sequence to remove; it must exist.
+   */
+  void RemoveSequence(int64_t seq_id) final {
+    auto it = seq_map_.find(seq_id);
+    CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+                                << "\" cannot be found in the space state storage.";
+
+    // Remember the slot, drop the bookkeeping entry, then recycle the slot.
+    const auto released_slot = it->second.seq_slot_id;
+    seq_map_.erase(it);
+    free_slot_ids_.push_back(released_slot);
+
+    dirty_aux_data_device_ = true;
+  }
+
+  /*!
+   * \brief Fork a child sequence from an existing parent sequence.
+   * The child takes a fresh sequence slot and receives a copy of the parent's
+   * whole slot (all history entries) for every state of every layer.
+   * \param parent_seq_id The id of the existing sequence to fork from.
+   * \param child_seq_id The id of the new sequence; it must not exist yet.
+   */
+  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
+    auto parent_it = seq_map_.find(parent_seq_id);
+    CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id
+                                       << "\" cannot be found in space state storage.";
+    CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
+        << "The child sequence \"" << child_seq_id << "\" is already in the space state storage.";
+
+    // Read the parent's slot *before* inserting the child: inserting into `seq_map_`
+    // may invalidate `parent_it` (e.g. on a rehash if the map is a hash map), so the
+    // iterator must not be dereferenced after the insertion.
+    const int64_t parent_slot_id = parent_it->second.seq_slot_id;
+
+    // Register the child on a fresh slot via `Sequence::Fork`.
+    const int64_t child_slot_id = GetFreeSlot();
+    seq_map_.insert({child_seq_id, Sequence::Fork(parent_it->second, child_slot_id)});
+
+    // Copy the parent state data to the child state data.
+    for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
+      for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id) {
+        DLTensor copy_src = GetStatePtrBySeq(layer_id, state_id, parent_slot_id);
+        DLTensor copy_dst = GetStatePtrBySeq(layer_id, state_id, child_slot_id);
+        NDArray::CopyFromTo(&copy_src, &copy_dst);
+      }
+    }
+    dirty_aux_data_device_ = true;
+  }
+
+  /*!
+   * \brief Roll a sequence back by `n` steps.
+   * \param seq_id The id of the sequence to roll back; it must exist.
+   * \param n The number of steps to roll back. Must be non-negative and must not
+   * exceed the available history of the sequence.
+   */
+  void PopN(int64_t seq_id, int32_t n) final {
+    auto it = seq_map_.find(seq_id);
+    CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+                                << "\" cannot be found in space state.";
+    Sequence& seq = it->second;
+    CHECK_GE(n, 0) << "The length of rolling back " << n << " cannot be negative.";
+    CHECK_LE(n, seq.available_history_num)
+        << "The sequence only has " << seq.available_history_num
+        << " available history in the space state storage, while the length of rollback is " << n
+        << " which exceeds the sequence length.";
+
+    seq.seq_length -= n;
+    seq.available_history_num -= n;
+    // Step the ring-buffer index backwards; adding max_history_ keeps it non-negative.
+    seq.history_slot_id = (seq.history_slot_id - n + max_history_) % max_history_;
+    dirty_aux_data_device_ = true;
+  }
+
+ private:
+  /*! \brief Take a free sequence slot off the back of the free list and return its index. */
+  int32_t GetFreeSlot() {
+    CHECK(!free_slot_ids_.empty()) << "The Sequence slot is full, cannot accept new sequence.";
+    const int32_t slot = free_slot_ids_.back();
+    free_slot_ids_.pop_back();
+    return slot;
+  }
+
+  /*!
+   * \brief Build a DLTensor that addresses one (sequence slot, history slot) entry
+   * of a state storage.
+   * \return A copy of the storage's DLTensor with the two leading dimensions dropped
+   * and `byte_offset` set to the start of the requested entry. Its `shape` pointer
+   * points into the storage's shape array.
+   */
+  DLTensor GetStatePtrBySeqHistory(int64_t layer_id, int64_t state_id, int64_t seq_slot_id,
+                                   int64_t history_slot_id) {
+    NDArray storage = storages_[layer_id][state_id];
+    // Number of elements of one entry: the product of all dims after the first two.
+    int64_t entry_size = 1;
+    for (int64_t dim = 2; dim < storage->ndim; ++dim) {
+      entry_size *= storage->shape[dim];
+    }
+    int64_t elem_offset = (seq_slot_id * max_history_ + history_slot_id) * entry_size;
+    // Start from a copy of the storage tensor and narrow it to the requested entry.
+    DLTensor view = *(storage.operator->());
+    view.byte_offset = elem_offset * storage->dtype.bits / 8;
+    view.ndim = storage->ndim - 2;
+    view.shape += 2;
+    return view;
+  }
+
+  /*!
+   * \brief Build a DLTensor that addresses the whole slot (all history entries) of
+   * one sequence in a state storage.
+   * \return A copy of the storage's DLTensor with the leading dimension dropped and
+   * `byte_offset` set to the start of the requested sequence slot. Its `shape`
+   * pointer points into the storage's shape array.
+   */
+  DLTensor GetStatePtrBySeq(int64_t layer_id, int64_t state_id, int64_t seq_slot_id) {
+    NDArray storage = storages_[layer_id][state_id];
+    // Number of elements of one sequence slot: the product of all dims after the first.
+    int64_t slot_size = 1;
+    for (int64_t dim = 1; dim < storage->ndim; ++dim) {
+      slot_size *= storage->shape[dim];
+    }
+    int64_t elem_offset = seq_slot_id * slot_size;
+    // Start from a copy of the storage tensor and narrow it to the requested slot.
+    DLTensor view = *(storage.operator->());
+    view.byte_offset = elem_offset * storage->dtype.bits / 8;
+    view.ndim = storage->ndim - 1;
+    view.shape += 1;
+    return view;
+  }
+
+  /*!
+   * \brief Synchronize auxiliary arrays to device.
+   * Rebuilds the sequence-slot-id and history-slot-id arrays for the sequences of
+   * the current batch on host and copies them into views of the device arrays.
+   * \note This method resets the dirty flag to false, and needs to be
+   * invoked before `Get`/`Set` are run on device.
+   */
+  void SyncAuxArrayToDevice() {
+    // Copy a host vector into a 1-D device array. The vector is taken by const
+    // reference so that the host data is not copied on every call.
+    auto fcopy_from_vec = [](NDArray array, const std::vector<int32_t>& vec_data) {
+      DLTensor copy_dst = *array.operator->();
+      DLTensor copy_src;
+      // DLTensor::data is a non-const pointer; the source is only read by the copy.
+      copy_src.data = const_cast<int32_t*>(vec_data.data());
+      copy_src.device = Device{kDLCPU, 0};
+      copy_src.ndim = 1;
+      copy_src.dtype = array->dtype;
+      copy_src.shape = array->shape;
+      copy_src.strides = nullptr;
+      copy_src.byte_offset = 0;
+      NDArray::CopyFromTo(&copy_src, &copy_dst);
+    };
+
+    // Collect the slot ids of the sequences in the current batch, in batch order.
+    std::vector<int32_t> seq_slot_ids;
+    std::vector<int32_t> history_slot_ids;
+    seq_slot_ids.reserve(cur_batch_size_);
+    history_slot_ids.reserve(cur_batch_size_);
+    for (int64_t seq_id : cur_seq_ids_) {
+      auto it = seq_map_.find(seq_id);
+      CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id
+                                  << "\" cannot be found in the space state storage.";
+      const Sequence& seq = it->second;
+      seq_slot_ids.push_back(seq.seq_slot_id);
+      history_slot_ids.push_back(seq.history_slot_id);
+    }
+    // Reuse the pre-allocated device arrays through views sized to the current batch.
+    seq_slot_ids_view_ = seq_slot_ids_device_.CreateView({cur_batch_size_}, dtype_aux_);
+    history_slot_ids_view_ = history_slot_ids_device_.CreateView({cur_batch_size_}, dtype_aux_);
+
+    fcopy_from_vec(seq_slot_ids_view_, seq_slot_ids);
+    fcopy_from_vec(history_slot_ids_view_, history_slot_ids);
+
+    // Reset the dirty flag to false.
+    dirty_aux_data_device_ = false;
+  }
+
+ public:
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "relax.vm.RNNStateImp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RNNStateImpObj, RNNStateObj);
+};
+
+TVM_REGISTER_OBJECT_TYPE(RNNStateImpObj);
+
+//-------------------------------------------------
+//  Register runtime functions
+//-------------------------------------------------
+
+// Create an RNNState after validating the configuration. The device is taken from
+// the first initial value; all initial values must live on that same device, and
+// one getter and one setter must be given per state of a layer.
+TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_create")
+    .set_body_typed([](int64_t num_layers,         //
+                       int64_t reserved_num_seqs,  //
+                       int64_t max_history,        //
+                       Array<PackedFunc> f_gets,   //
+                       Array<PackedFunc> f_sets,   //
+                       Array<NDArray> init_layer_value) {
+      CHECK_GT(num_layers, 0) << "The number of layers should be greater than 0.";
+      CHECK_GT(reserved_num_seqs, 0)
+          << "The number of reserved sequences should be greater than 0.";
+      // At least one history slot is required to hold the current state, which is
+      // also what the RNNStateImpObj constructor enforces.
+      CHECK_GT(max_history, 0) << "The maximum history length should be greater than 0.";
+      CHECK_GT(init_layer_value.size(), 0)
+          << "The number of states per layer should be greater than 0.";
+      Device device = init_layer_value[0]->device;
+      for (const NDArray& state : init_layer_value) {
+        CHECK(state->device.device_type == device.device_type &&
+              state->device.device_id == device.device_id)
+            << "The device type of all states should be the same.";
+      }
+      CHECK_EQ(f_gets.size(), init_layer_value.size())
+          << "The number of state getters should be the same as the number of states per layer, "
+          << "but got " << f_gets.size() << " and " << init_layer_value.size() << " respectively.";
+      CHECK_EQ(f_sets.size(), init_layer_value.size())
+          << "The number of state setters should be the same as the number of states per layer, "
+          << "but got " << f_sets.size() << " and " << init_layer_value.size() << " respectively.";
+      ObjectPtr<RNNStateImpObj> n =
+          make_object<RNNStateImpObj>(num_layers, reserved_num_seqs, max_history, device,
+                                      std::move(f_gets), std::move(f_sets), init_layer_value);
+      return RNNState(std::move(n));
+    });
+
+}  // namespace relax_vm
+}  // namespace runtime
+}  // namespace tvm
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py 
b/tests/python/relax/test_runtime_builtin_rnn_state.py
new file mode 100644
index 0000000000..28f370bca0
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring,
+from typing import Sequence, Union
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import dlight as dl
+from tvm import tir
+from tvm.runtime import ShapeTuple
+from tvm.script import tir as T
+
+# pylint: disable=invalid-name
+
+# Constant state values used by the tests. `np_zero` / `np_one` are the initial
+# values handed to the RNN state; `np_two` / `np_three` are the values the tests write.
+np_zero = np.full((16, 16), 0.0, "float16")
+np_one = np.full((32, 32), 1.0, "float32")
+np_two = np.full((16, 16), 2.0, "float16")
+np_three = np.full((32, 32), 3.0, "float32")
+
+# Configuration of the RNN state under test.
+reserved_nseq = 4
+max_history = 4
+num_layers = 1
+device = tvm.cuda()
+# Note that kernels in this test file cannot support 1-dim states.
+# Each entry is the (shape, dtype) of one state of a layer.
+states = [((16, 16), "float16"), ((32, 32), "float32")]
+
+# Runtime builtins; populated by `set_global_func`.
+f_clear = None
+f_add_sequence = None
+f_remove_sequence = None
+f_fork_sequence = None
+f_popn = None
+f_begin_forward = None
+f_end_forward = None
+f_get = None
+f_set = None
+f_debug_get = None
+
+# Compiled TIR getter/setter kernels, one per entry of `states`; populated by
+# `set_global_func`.
+f_tir_gets = []
+f_tir_sets = []
+
+# pylint: enable=invalid-name
+
+
+def set_global_func():
+    global f_clear, f_add_sequence, f_remove_sequence, f_fork_sequence, f_popn
+    global f_begin_forward, f_end_forward, f_get, f_set, f_debug_get
+    global f_tir_gets, f_tir_sets
+
+    f_clear = tvm.get_global_func("vm.builtin.kv_state_clear")
+    f_add_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
+    f_remove_sequence = 
tvm.get_global_func("vm.builtin.kv_state_remove_sequence")
+    f_fork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence")
+    f_popn = tvm.get_global_func("vm.builtin.kv_state_popn")
+    f_begin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
+    f_end_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
+    f_get = tvm.get_global_func("vm.builtin.rnn_state_get")
+    f_set = tvm.get_global_func("vm.builtin.rnn_state_set")
+    f_debug_get = tvm.get_global_func("vm.builtin.rnn_state_debug_get")
+
+    target = tvm.target.Target("cuda")
+
+    def _build(tir_func):
+        mod = tvm.IRModule({"main": tir_func})
+        with target:
+            mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)  # pylint: 
disable=not-callable
+        f = tvm.build(mod["main"], target=target)
+        return f.entry_func
+
+    _f_tir_gets, _f_tir_sets = [], []
+    for state in states:
+        shape, dtype = state
+        _f_tir_gets.append(_build(rnn_state_get(shape, dtype)))
+        _f_tir_sets.append(_build(rnn_state_set(shape, dtype)))
+
+    f_tir_gets = _f_tir_gets
+    f_tir_sets = _f_tir_sets
+
+
+def create_rnn_state():
+    f_create = tvm.get_global_func("vm.builtin.rnn_state_create")
+    init_values = [tvm.nd.array(np_zero, device=device), tvm.nd.array(np_one, 
device=device)]
+    return f_create(num_layers, reserved_nseq, max_history, f_tir_gets, 
f_tir_sets, init_values)
+
+
[email protected]
+def rnn_state():
+    set_global_func()
+    return create_rnn_state()
+
+
+def verify_state(state, seq_ids, expected_values):
+    layer_id = 0
+    for seq_id in seq_ids:
+        for state_id, expected_value in enumerate(expected_values[seq_id]):
+            state_value = f_debug_get(state, layer_id, state_id, seq_id)
+            tvm.testing.assert_allclose(state_value.numpy(), expected_value)
+
+
[email protected]_cuda
+def test_rnn_state_get(rnn_state):  # pylint: disable=redefined-outer-name
+    state = rnn_state
+    f_clear(state)
+    f_add_sequence(state, 0)
+    f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
+    tvm_nd_0 = tvm.nd.array(np.empty((1, 16, 16), "float16"), device=device)
+    tvm_nd_1 = tvm.nd.array(np.empty((1, 32, 32), "float32"), device=device)
+    f_get(state, 0, 0, tvm_nd_0)
+    f_get(state, 0, 1, tvm_nd_1)
+    f_end_forward(state)
+    tvm.testing.assert_allclose(tvm_nd_0.numpy(), np.zeros((1, 16, 16), 
"float16"))
+    tvm.testing.assert_allclose(tvm_nd_1.numpy(), np.ones((1, 32, 32), 
"float32"))
+
+
[email protected]_cuda
+def test_rnn_state_set(rnn_state):  # pylint: disable=redefined-outer-name
+    state = rnn_state
+    f_clear(state)
+    for seq_id in range(3):
+        f_add_sequence(state, seq_id)
+    f_begin_forward(state, ShapeTuple([0, 2]), ShapeTuple([1, 1]))
+
+    f_set(state, 0, 0, tvm.nd.array(np.full((2, 16, 16), 2.0, "float16"), 
device=device))
+    f_set(state, 0, 1, tvm.nd.array(np.full((2, 32, 32), 3.0, "float32"), 
device=device))
+    f_end_forward(state)
+
+    expected_values = [[np_two, np_three], [np_zero, np_one], [np_two, 
np_three]]
+    verify_state(state, [0, 1, 2], expected_values)
+
+
[email protected]_cuda
+def test_rnn_state_popn(rnn_state):  # pylint: disable=redefined-outer-name
+    state = rnn_state
+    f_clear(state)
+
+    f_add_sequence(state, 0)
+    f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
+    f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device))
+    f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), 
device=device))
+    f_end_forward(state)
+
+    verify_state(state, [0], [[np_two, np_three]])
+    f_popn(state, 0, 1)
+    verify_state(state, [0], [[np_zero, np_one]])
+    with pytest.raises(tvm.error.TVMError):
+        f_popn(state, 0, 1)  # no available history to pop
+
+
[email protected]_cuda
+def test_rnn_state_fork_sequence(rnn_state):  # pylint: 
disable=redefined-outer-name
+    state = rnn_state
+    f_clear(state)
+
+    f_add_sequence(state, 0)
+    f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1]))
+    f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device))
+    f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), 
device=device))
+    f_end_forward(state)
+    f_fork_sequence(state, 0, 1)
+    verify_state(state, [0, 1], [[np_two, np_three], [np_two, np_three]])
+    # Verify popn for the forked sequence
+    f_popn(state, 1, 1)
+    verify_state(state, [0, 1], [[np_two, np_three], [np_zero, np_one]])
+
+
+def rnn_state_get(
+    shape: Sequence[int],
+    dtype: str,
+):
+    """Create the TIR getter kernel for one state of the given shape and dtype.
+
+    For every batch entry `i`, the returned prim_func copies
+    `storage[seq_slot_ids[i], history_slot_ids[i], ...]` into `output[i, ...]`.
+    """
+    # fmt: off
+    @T.prim_func
+    def _rnn_state_get(
+        var_storage: T.handle,
+        var_seq_slot_ids: T.handle,
+        var_history_slot_ids: T.handle,
+        var_output: T.handle,
+    ):
+        batch_size = T.int32(is_size_var=True)
+
+        storage = T.match_buffer(var_storage, (reserved_nseq, max_history, 
*shape), dtype)
+        seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32")
+        history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), 
"int32")
+        output = T.match_buffer(var_output, (batch_size, *shape), dtype)
+
+        for i in range(batch_size):
+            for s in T.grid(*shape):
+                with T.block("copy"):
+                    vi, *vs = T.axis.remap("S" * (len(shape) + 1), [i, *s])
+                    seq_id: T.int32 = seq_slot_ids[vi]
+                    history_id: T.int32 = history_slot_ids[vi]
+                    # The following line is equivalent to:
+                    # `output[vi, *vs] = storage[seq_id, history_id, *vs]`
+                    # However, unpacking operator in subscript requires Python 
3.11 or newer
+                    T.buffer_store(
+                        output, T.BufferLoad(storage, [seq_id, history_id, 
*vs]), [vi, *vs]
+                    )
+    # fmt: on
+    return _rnn_state_get
+
+
+def rnn_state_set(
+    shape: Sequence[Union[int, tir.Var]],
+    dtype: str,
+):
+    """Create the TIR setter kernel for one state of the given shape and dtype.
+
+    For every batch entry `i`, the returned prim_func copies `data[i, ...]` into
+    `storage[seq_slot_ids[i], (history_slot_ids[i] + 1) % max_history, ...]`,
+    i.e. into the history slot following the one given in `history_slot_ids`.
+    """
+    # fmt: off
+    @T.prim_func
+    def _rnn_state_set(
+        var_storage: T.handle,
+        var_seq_slot_ids: T.handle,
+        var_history_slot_ids: T.handle,
+        var_data: T.handle,
+    ):
+        batch_size = T.int32(is_size_var=True)
+
+        storage = T.match_buffer(var_storage, (reserved_nseq, max_history, 
*shape), dtype)
+        seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32")
+        history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), 
"int32")
+        data = T.match_buffer(var_data, (batch_size, *shape), dtype)
+
+        for i in range(batch_size):
+            for s in T.grid(*shape):
+                with T.block("copy"):
+                    vi, *vs = T.axis.remap("S" * (len(shape) + 1), [i, *s])
+                    seq_id: T.int32 = seq_slot_ids[vi]
+                    history_id: T.int32 = (history_slot_ids[vi] + 1) % T.cast(
+                        max_history, "int32"
+                    )
+                    # The following line is equivalent to:
+                    # `storage[seq_id, history_id, *vs] = data[vi, *vs]`
+                    # However, unpacking operator in subscript requires Python 
3.11 or newer
+                    T.buffer_store(
+                        storage, T.BufferLoad(data, [vi, *vs]), [seq_id, 
history_id, *vs]
+                    )
+
+    # fmt: on
+
+    return _rnn_state_set
+
+
+if __name__ == "__main__":
+    set_global_func()
+    rnn_state = create_rnn_state()
+    test_rnn_state_get(rnn_state)
+    test_rnn_state_set(rnn_state)
+    test_rnn_state_popn(rnn_state)
+    test_rnn_state_fork_sequence(rnn_state)

Reply via email to