You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ru...@apache.org on 2024/02/01 20:54:58 UTC
(tvm) branch main updated: [Relax][Web] Add ApplyPresenceAndFrequencyPenalty (#16504)
This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5c45ae8e53 [Relax][Web] Add ApplyPresenceAndFrequencyPenalty (#16504)
5c45ae8e53 is described below
commit 5c45ae8e53f4f2db1507d1cfe00656c7437c3d4c
Author: Charlie Ruan <53...@users.noreply.github.com>
AuthorDate: Thu Feb 1 15:54:52 2024 -0500
[Relax][Web] Add ApplyPresenceAndFrequencyPenalty (#16504)
This PR adds `ApplyPresenceAndFrequencyPenalty()` to
`lm_support.cc` and exposes it to Web runtime.
This is essentially the same as `applyRepetitionPenalty` except we
follow a different way of penalizing repeating tokens, following
https://platform.openai.com/docs/guides/text-generation/frequency-and-presence-penalties.
Tested end-to-end with WebLLM.
---
src/runtime/relax_vm/lm_support.cc | 36 ++++++++++++++++++++++++++++++++++++
web/src/runtime.ts | 24 ++++++++++++++++++++++++
2 files changed, 60 insertions(+)
diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 706e2c3d5f..ecaacb7770 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -516,6 +516,42 @@ void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) {
TVM_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty").set_body_typed(ApplyRepetitionPenalty);
+/*!
+ * \brief Apply presence and frequency penalty. This is an inplace operation.
+ * \param logits The input logits before penalty.
+ * \param token_ids The appeared token ids.
+ * \param token_freqs The number of times each token has appeared since last PrefillStep.
+ * token_freqs[i] is the frequency of token_ids[i], for all i. And all token_freqs should be >= 1.
+ * \param presence_penalty The penalty factor, applied if a token appeared in an one-off manner.
+ * \param frequency_penalty The penalty factor, contributes more the more frequent a token appears.
+ */
+void ApplyPresenceAndFrequencyPenalty(NDArray logits, NDArray token_ids, NDArray token_freqs,
+ double presence_penalty, double frequency_penalty) {
+ // See https://platform.openai.com/docs/guides/text-generation/frequency-and-presence-penalties
+ ICHECK(logits.IsContiguous());
+ ICHECK(token_ids.IsContiguous());
+ ICHECK(token_freqs.IsContiguous());
+ ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+ ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!";
+ ICHECK(token_freqs.DataType() == DataType::Int(32)) << "token freqs must be int32!";
+ ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
+ ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!";
+ ICHECK(token_freqs->device.device_type == kDLCPU) << "token_ids device must be CPU!";
+
+ float* logits_raw_data = static_cast<float*>(logits->data);
+ int* token_ids_data = static_cast<int*>(token_ids->data);
+ int* token_freqs_data = static_cast<int*>(token_freqs->data);
+ size_t num_token_ids = token_ids->shape[token_ids->ndim - 1];
+ for (size_t i = 0; i < num_token_ids; ++i) {
+ int token_id = token_ids_data[i];
+ int token_freq = token_freqs_data[i];
+ logits_raw_data[token_id] -= (token_freq * frequency_penalty + presence_penalty);
+ }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.apply_presence_and_frequency_penalty")
+ .set_body_typed(ApplyPresenceAndFrequencyPenalty);
+
// This is an inplace operation.
void ApplySoftmaxWithTemperature(NDArray logits, double temperature) {
ICHECK(logits.IsContiguous());
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 4c56005261..cf2a6069e4 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -158,6 +158,7 @@ class RuntimeContext implements Disposable {
ndarrayCreateView: PackedFunc;
sampleTopPFromLogits: PackedFunc;
applyRepetitionPenalty: PackedFunc;
+ applyPresenceAndFrequencyPenalty: PackedFunc;
applySoftmaxWithTemperature: PackedFunc;
private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
@@ -180,6 +181,7 @@ class RuntimeContext implements Disposable {
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
+ this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
}
@@ -202,6 +204,7 @@ class RuntimeContext implements Disposable {
this.ndarrayCreateView.dispose();
this.sampleTopPFromLogits.dispose();
this.applyRepetitionPenalty.dispose();
+ this.applyPresenceAndFrequencyPenalty.dispose();
this.applySoftmaxWithTemperature.dispose();
}
@@ -1757,6 +1760,27 @@ export class Instance implements Disposable {
return this.ctx.applyRepetitionPenalty(logits, token_ids, penalty);
}
+ /**
+ * Apply presence and frequency penalty. This is an inplace operation.
+ * @param logits The input logits before penalty.
+ * @param token_ids The appeared token ids.
+ * @param token_freqs The number of times each token has appeared since last PrefillStep.
+ * token_freqs[i] is the frequency of token_ids[i], for all i. And all token_freqs should be >= 1.
+ * @param presence_penalty The penalty factor.
+ * @param frequency_penalty The penalty factor.
+ */
+ applyPresenceAndFrequencyPenalty(
+ logits: NDArray,
+ token_ids: NDArray,
+ token_freqs: NDArray,
+ presence_penalty: number,
+ frequency_penalty: number
+ ) {
+ return this.ctx.applyPresenceAndFrequencyPenalty(
+ logits, token_ids, token_freqs, presence_penalty, frequency_penalty
+ );
+ }
+
/**
* Apply softmax with temperature to the logits.
* @param logits The input logits before softmax w/ temperature.