You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/05/26 16:55:31 UTC

[tvm] branch unity updated: [Unity] support update KV cache (#14964)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7d78b278d3 [Unity] support update KV cache (#14964)
7d78b278d3 is described below

commit 7d78b278d3d681a9e70d76d6a1e20646475f22cf
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Sat May 27 00:55:24 2023 +0800

    [Unity] support update KV cache (#14964)
    
    This PR adds the API `vm.builtin.attention_kv_cache_update` to support
    RWKV.
---
 src/runtime/relax_vm/lm_support.cc | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 8f7e8ebdf9..2df8f278f4 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -85,6 +85,19 @@ class AttentionKVCacheObj : public Object {
   /** Clear the cache */
   void Clear() { this->fill_count = 0; }
 
+  void Update(NDArray value) {  // Overwrite the first fill_count entries (axis 0) with value.
+    CHECK(data.DataType() == value.DataType()) << "dtype mismatch";
+    CHECK_EQ(value->ndim, data->ndim) << "Dimensions mismatch";
+    CHECK_EQ(value->shape[0], fill_count) << "Requested shape do not match the filled count";
+    for (int i = 1; i < data->ndim; ++i) {
+      CHECK_EQ(value->shape[i], data->shape[i]) << "Dimensions mismatch";
+    }
+    ICHECK(data.IsContiguous() && value.IsContiguous());
+    DLTensor copy_dst = *(data.operator->());  // data's buffer, viewed with value's extent
+    copy_dst.byte_offset = 0;
+    copy_dst.shape = value->shape;
+    NDArray::CopyFromTo(value.operator->(), &copy_dst);
+  }
+
   /*!
    * \brief Append value to the cache.
    * \param value The value to be appended.
@@ -154,6 +167,13 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create")
     .set_body_typed(AttentionKVCache::Create);
 
+AttentionKVCache AttentionKVCacheUpdate(AttentionKVCache cache, NDArray value) {
+  cache->Update(value);  // copies value into the cache's existing buffer
+  return cache;          // hands back the same (now updated) cache object
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update").set_body_typed(AttentionKVCacheUpdate);
+
 AttentionKVCache AttentionKVCacheAppend(AttentionKVCache cache, NDArray value) {
   cache->Append(value);
   return cache;