You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/06/12 19:12:26 UTC
[tvm] branch unity updated: [Unity] Optimize SampleTopPFromProb (#15072)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f61015c20c [Unity] Optimize SampleTopPFromProb (#15072)
f61015c20c is described below
commit f61015c20cfbcef5f3a466a12a658ca40b221f47
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Mon Jun 12 15:12:15 2023 -0400
[Unity] Optimize SampleTopPFromProb (#15072)
This PR optimizes sample top p from prob, cutting its running time
by roughly 10x or more on M1. The main observation is that we don't
need a full sort and can just select top p based on a filtered cap.
---
src/runtime/relax_vm/lm_support.cc | 94 +++++++++++++++++++++++++-------------
1 file changed, 63 insertions(+), 31 deletions(-)
diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 9b14161e67..bdee444608 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -351,43 +351,75 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) {
ICHECK_EQ(prob->shape[i], 1) << "The leading dimensions of logits must be 1";
}
+ // Key observation: when we are doing top_p sampling
+ // usually we only need to preserve some of the elements with
+ // high probabilities before we do sort
std::vector<std::pair<float, int>> data;
- data.resize(prob->shape[prob->ndim - 1]);
+ int64_t ndata = prob->shape[prob->ndim - 1];
const float* p_prob = static_cast<float*>(prob->data);
- for (size_t i = 0; i < data.size(); ++i) {
- data[i] = std::make_pair(p_prob[i], static_cast<int>(i));
- }
-
- auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
- return lhs.first > rhs.first;
- };
-
- // sort by logits from largest to smallest
- std::sort(data.begin(), data.end(), fcmp);
- if (top_p < 1e-6f) {
- return data.begin()->second;
- }
-
- // do a cumsum in order of data
- float cum_sum_prob = 0.0f;
- float top_p_sum = 0.0f;
- for (auto it = data.begin(); it != data.end(); ++it) {
- float prob = it->first;
- if (cum_sum_prob < top_p) {
- top_p_sum += prob;
+ auto sample_top_p_with_filter = [&](float cuttoff) -> int64_t {
+ data.clear();
+ // filter the data with cuttoff
+ for (int64_t i = 0; i < ndata; ++i) {
+ if (p_prob[i] >= cuttoff) {
+ data.emplace_back(std::make_pair(p_prob[i], static_cast<int>(i)));
+ }
}
- cum_sum_prob += prob;
- it->first = cum_sum_prob;
- }
- // pick a number based on random in (0, 1)
- for (auto it = data.begin(); it != data.end(); ++it) {
- if (uniform_sample < it->first / top_p_sum) {
- return it->second;
+ if (data.size() == 0) return -1;
+ auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
+ return lhs.first > rhs.first;
+ };
+ std::sort(data.begin(), data.end(), fcmp);
+
+ // short cut, if we know that
+ // uniform sample < p[0] / top_p
+ // we know that uniform_sample < p[0] / top_p_sum
+ // because top_p_sum is guaranteed to be smaller than top_p
+ // so we can simply return the argmax sample
+ // without computing anything
+ if (uniform_sample < data[0].first / top_p) return data[0].second;
+
+ // compute top_p_sum
+ float cum_sum_prob = 0.0f;
+ float top_p_sum = 0.0f;
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ float prob = it->first;
+ if (cum_sum_prob < top_p) {
+ top_p_sum += prob;
+ } else {
+ // we get to the right cutoff pt
+ break;
+ }
+ cum_sum_prob += prob;
+ it->first = cum_sum_prob;
+ }
+ // we find that the current total sum by the given cutoff
+ // is not sufficient to cover everything
+ // this means we might need to retry a smaller cutoff pt.
+ if (cum_sum_prob < top_p && cuttoff != 0.0f) return -1;
+
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ if (uniform_sample < it->first / top_p_sum) {
+ return it->second;
+ }
}
+ return data[data.size() - 1].second;
+ };
+
+ if (top_p < 1) {
+ // sample through cutoff by a number
+ // by pigeonhole principle we will get at most 1024 elements
+ // usually it is much less by applying this filtering(order of 10 - 20)
+ data.reserve(128);
+ int64_t sampled_index = sample_top_p_with_filter(top_p / 1024);
+ if (sampled_index >= 0) return sampled_index;
}
- ICHECK_LE(uniform_sample, data[0].first);
- return data[0].second;
+ // fallback via full prob, rare case
+ data.reserve(ndata);
+ int64_t sampled_index = sample_top_p_with_filter(0.0f);
+ ICHECK_GE(sampled_index, 0);
+ return sampled_index;
}
TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);