jinhongyii commented on code in PR #14478: URL: https://github.com/apache/tvm/pull/14478#discussion_r1156647265
########## src/runtime/relax_vm/attention_kv_cache.cc: ########## @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/runtime/relax_vm/attention_kv_cache.cc + * \brief A simple implementation of inplace attention kv cache for runtime. + * + * This file provides a simple implementation of inplace attention + * kv cache for relax runtime. The main goal here is to help us enable + * auto-regressive decoding quickly in relax. + * + * This is not the only way to support attention kv-cache. + * Our support of attention kv-cache can subject to future + * changes as we build more LM verticals. + * + * We will keep the impact minimum by puting it as a private + * runtime builtin provide as in this file. + * + * We can evolve this implementation as we build more LM verticals. 
+ */ + +#include <tvm/runtime/container/shape_tuple.h> +#include <tvm/runtime/device_api.h> +#include <tvm/runtime/logging.h> +#include <tvm/runtime/memory.h> +#include <tvm/runtime/ndarray.h> +#include <tvm/runtime/relax_vm/vm.h> + +namespace tvm { +namespace runtime { +namespace relax_vm { + +//------------------------------------------- +// We keep the implementation private as +// they may subject to future changes. +// +// Users can interact with it through the +// runtime API function calls +//------------------------------------------- +/*! + * \brief An object representing an attention kv cache. + */ +class AttentionKVCacheObj : public Object { + public: + /*! + * \brief Underlying support data. + */ + NDArray data; + + /*! + * \brief number of slots already filled. + */ + int64_t fill_count{0}; + + /*! + * \brief View all current cached values as one array. + * \param shape The cached values. + */ + NDArray View(const ShapeTuple& shape) { + CHECK_EQ(shape[0], fill_count) << "Requested shape do not match the filled count"; + for (int i = 1; i < this->data->ndim; ++i) { + CHECK_EQ(shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; + } + return data.CreateView(shape, data->dtype); + } + + /*! + * \brief Append value to the cache. + * \param value The value to be appended. + */ + void Append(NDArray value) { Review Comment: I think we need to specify which dimension to append along, because seq_len may not be the first dimension of the KV cache. ########## src/runtime/relax_vm/attention_kv_cache.cc: ########## @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/runtime/relax_vm/attention_kv_cache.cc + * \brief A simple implementation of inplace attention kv cache for runtime. + * + * This file provides a simple implementation of inplace attention + * kv cache for relax runtime. The main goal here is to help us enable + * auto-regressive decoding quickly in relax. + * + * This is not the only way to support attention kv-cache. + * Our support of attention kv-cache can subject to future + * changes as we build more LM verticals. + * + * We will keep the impact minimum by puting it as a private + * runtime builtin provide as in this file. + * + * We can evolve this implementation as we build more LM verticals. + */ + +#include <tvm/runtime/container/shape_tuple.h> +#include <tvm/runtime/device_api.h> +#include <tvm/runtime/logging.h> +#include <tvm/runtime/memory.h> +#include <tvm/runtime/ndarray.h> +#include <tvm/runtime/relax_vm/vm.h> + +namespace tvm { +namespace runtime { +namespace relax_vm { + +//------------------------------------------- +// We keep the implementation private as +// they may subject to future changes. +// +// Users can interact with it through the +// runtime API function calls +//------------------------------------------- +/*! + * \brief An object representing an attention kv cache. + */ +class AttentionKVCacheObj : public Object { + public: + /*! + * \brief Underlying support data. + */ + NDArray data; + + /*! + * \brief number of slots already filled. + */ + int64_t fill_count{0}; + + /*! + * \brief View all current cached values as one array. 
+ * \param shape The cached values. + */ + NDArray View(const ShapeTuple& shape) { Review Comment: Why do we need to pass `shape` as an argument? Is it only used for the checks? The body requires `shape[0] == fill_count` and every other dimension to equal `data->shape`, so the view shape seems fully determined by the cache state. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
