You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ru...@apache.org on 2024/02/01 20:54:58 UTC

(tvm) branch main updated: [Relax][Web] Add ApplyPresenceAndFrequencyPenalty (#16504)

This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c45ae8e53 [Relax][Web] Add ApplyPresenceAndFrequencyPenalty (#16504)
5c45ae8e53 is described below

commit 5c45ae8e53f4f2db1507d1cfe00656c7437c3d4c
Author: Charlie Ruan <53...@users.noreply.github.com>
AuthorDate: Thu Feb 1 15:54:52 2024 -0500

    [Relax][Web] Add ApplyPresenceAndFrequencyPenalty (#16504)
    
    This PR adds `ApplyPresenceAndFrequencyPenalty()` to
    `lm_support.cc` and exposes it to Web runtime.
    
    This is essentially the same as `applyRepetitionPenalty`, except that
    repeated tokens are penalized using the presence/frequency scheme described at
    https://platform.openai.com/docs/guides/text-generation/frequency-and-presence-penalties.
    
    Tested end-to-end with WebLLM.
---
 src/runtime/relax_vm/lm_support.cc | 36 ++++++++++++++++++++++++++++++++++++
 web/src/runtime.ts                 | 24 ++++++++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 706e2c3d5f..ecaacb7770 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -516,6 +516,42 @@ void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) {
 
 TVM_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty").set_body_typed(ApplyRepetitionPenalty);
 
+/*!
+ * \brief Subtract presence and frequency penalties from logits. This is an inplace operation.
+ * \param logits The float32 CPU logits to penalize; indexed directly by each token id.
+ * \param token_ids The int32 CPU array of token ids that have appeared.
+ * \param token_freqs The number of times each token has appeared since last PrefillStep.
+ * token_freqs[i] is the frequency of token_ids[i], for all i. And all token_freqs should be >= 1.
+ * \param presence_penalty The flat penalty subtracted once for every token listed in token_ids.
+ * \param frequency_penalty The per-occurrence penalty; it is multiplied by the token's frequency.
+ */
+void ApplyPresenceAndFrequencyPenalty(NDArray logits, NDArray token_ids, NDArray token_freqs,
+                                      double presence_penalty, double frequency_penalty) {
+  // See https://platform.openai.com/docs/guides/text-generation/frequency-and-presence-penalties
+  ICHECK(logits.IsContiguous());
+  ICHECK(token_ids.IsContiguous());
+  ICHECK(token_freqs.IsContiguous());
+  ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+  ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!";
+  ICHECK(token_freqs.DataType() == DataType::Int(32)) << "token freqs must be int32!";
+  ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
+  ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!";
+  ICHECK(token_freqs->device.device_type == kDLCPU) << "token_freqs device must be CPU!";
+
+  float* logits_raw_data = static_cast<float*>(logits->data);
+  int* token_ids_data = static_cast<int*>(token_ids->data);
+  int* token_freqs_data = static_cast<int*>(token_freqs->data);
+  size_t num_token_ids = token_ids->shape[token_ids->ndim - 1];
+  for (size_t i = 0; i < num_token_ids; ++i) {
+    int token_id = token_ids_data[i];
+    int token_freq = token_freqs_data[i];
+    logits_raw_data[token_id] -= (token_freq * frequency_penalty + presence_penalty);
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.apply_presence_and_frequency_penalty")
+    .set_body_typed(ApplyPresenceAndFrequencyPenalty);
+
 // This is an inplace operation.
 void ApplySoftmaxWithTemperature(NDArray logits, double temperature) {
   ICHECK(logits.IsContiguous());
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 4c56005261..cf2a6069e4 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -158,6 +158,7 @@ class RuntimeContext implements Disposable {
   ndarrayCreateView: PackedFunc;
   sampleTopPFromLogits: PackedFunc;
   applyRepetitionPenalty: PackedFunc;
+  applyPresenceAndFrequencyPenalty: PackedFunc;
   applySoftmaxWithTemperature: PackedFunc;
 
   private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
@@ -180,6 +181,7 @@ class RuntimeContext implements Disposable {
     this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
     this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
     this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
+    this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
     this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
   }
 
@@ -202,6 +204,7 @@ class RuntimeContext implements Disposable {
     this.ndarrayCreateView.dispose();
     this.sampleTopPFromLogits.dispose();
     this.applyRepetitionPenalty.dispose();
+    this.applyPresenceAndFrequencyPenalty.dispose();
     this.applySoftmaxWithTemperature.dispose();
   }
 
@@ -1757,6 +1760,27 @@ export class Instance implements Disposable {
     return this.ctx.applyRepetitionPenalty(logits, token_ids, penalty);
   }
 
+  /**
+   * Apply presence and frequency penalty. This is an inplace operation.
+   * @param logits The input logits before penalty.
+   * @param token_ids The appeared token ids.
+   * @param token_freqs The number of times each token has appeared since last PrefillStep.
+   * token_freqs[i] is the frequency of token_ids[i], for all i. And all token_freqs should be >= 1.
+   * @param presence_penalty The penalty factor, applied once to every token in token_ids.
+   * @param frequency_penalty The penalty factor, scaled by how often each token has appeared.
+   */
+  applyPresenceAndFrequencyPenalty(
+    logits: NDArray,
+    token_ids: NDArray,
+    token_freqs: NDArray,
+    presence_penalty: number,
+    frequency_penalty: number
+  ) {
+    return this.ctx.applyPresenceAndFrequencyPenalty(
+      logits, token_ids, token_freqs, presence_penalty, frequency_penalty
+    );
+  }
+
   /**
    * Apply softmax with temperature to the logits.
    * @param logits The input logits before softmax w/ temperature.