Posted to commits@tvm.apache.org by ju...@apache.org on 2023/06/12 19:12:26 UTC

[tvm] branch unity updated: [Unity] Optimize SampleTopPFromProb (#15072)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f61015c20c [Unity] Optimize SampleTopPFromProb (#15072)
f61015c20c is described below

commit f61015c20cfbcef5f3a466a12a658ca40b221f47
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Mon Jun 12 15:12:15 2023 -0400

    [Unity] Optimize SampleTopPFromProb (#15072)
    
    This PR optimizes SampleTopPFromProb to reduce its running time by
    roughly 10x or more on M1. The main observation is that we do not
    need a full sort: we can select the top-p set from the elements
    that survive a probability cutoff.
---
 src/runtime/relax_vm/lm_support.cc | 94 +++++++++++++++++++++++++-------------
 1 file changed, 63 insertions(+), 31 deletions(-)
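
In sketch form, the optimized path works as follows: filter the
probabilities with a cutoff, sort only the survivors, accumulate mass
until it reaches top_p, then sample within that mass; a negative return
value signals that the cutoff was too aggressive and the caller retries
with cutoff = 0. Below is a condensed, self-contained rendering of the
code in the diff that follows (the helper name and exact signature are
illustrative, not the committed code):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Condensed rendering of the filtered top-p sampling idea.
// Returns the sampled index, or -1 when the filtered subset does not
// cover the requested top_p mass (the caller retries with cutoff = 0).
int64_t SampleTopPFiltered(const float* prob, int64_t n, double top_p,
                           double uniform_sample, float cutoff) {
  std::vector<std::pair<float, int64_t>> data;
  for (int64_t i = 0; i < n; ++i) {
    if (prob[i] >= cutoff) data.emplace_back(prob[i], i);
  }
  if (data.empty()) return -1;
  // Sort only the (small) filtered subset, largest probability first.
  auto fcmp = [](const std::pair<float, int64_t>& lhs,
                 const std::pair<float, int64_t>& rhs) {
    return lhs.first > rhs.first;
  };
  std::sort(data.begin(), data.end(), fcmp);
  // Shortcut: top_p_sum <= top_p, so this already implies
  // uniform_sample < data[0].first / top_p_sum.
  if (uniform_sample < data[0].first / top_p) return data[0].second;
  // Accumulate probability mass until it reaches top_p.
  float cum_sum = 0.0f;
  float top_p_sum = 0.0f;
  for (auto& kv : data) {
    if (cum_sum >= top_p) break;
    top_p_sum += kv.first;
    cum_sum += kv.first;
    kv.first = cum_sum;  // reuse the slot to store the running sum
  }
  // The survivors cannot cover top_p: the cutoff was too aggressive.
  if (cum_sum < top_p && cutoff != 0.0f) return -1;
  for (const auto& kv : data) {
    if (uniform_sample < kv.first / top_p_sum) return kv.second;
  }
  return data.back().second;
}

The caller first tries cutoff = top_p / 1024 and only falls back to a
full pass with cutoff = 0 in the rare case the filtered pass returns -1.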

diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 9b14161e67..bdee444608 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -351,43 +351,75 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) {
     ICHECK_EQ(prob->shape[i], 1) << "The leading dimensions of logits must be 1";
   }
 
+  // Key observation: when we are doing top_p sampling,
+  // we usually only need to preserve a few elements with
+  // high probabilities before we sort.
   std::vector<std::pair<float, int>> data;
-  data.resize(prob->shape[prob->ndim - 1]);
+  int64_t ndata = prob->shape[prob->ndim - 1];
   const float* p_prob = static_cast<float*>(prob->data);
-  for (size_t i = 0; i < data.size(); ++i) {
-    data[i] = std::make_pair(p_prob[i], static_cast<int>(i));
-  }
-
-  auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
-    return lhs.first > rhs.first;
-  };
-
-  // sort by logits from largest to smallest
-  std::sort(data.begin(), data.end(), fcmp);
 
-  if (top_p < 1e-6f) {
-    return data.begin()->second;
-  }
-
-  // do a cumsum in order of data
-  float cum_sum_prob = 0.0f;
-  float top_p_sum = 0.0f;
-  for (auto it = data.begin(); it != data.end(); ++it) {
-    float prob = it->first;
-    if (cum_sum_prob < top_p) {
-      top_p_sum += prob;
+  auto sample_top_p_with_filter = [&](float cutoff) -> int64_t {
+    data.clear();
+    // filter the data with the cutoff
+    for (int64_t i = 0; i < ndata; ++i) {
+      if (p_prob[i] >= cutoff) {
+        data.emplace_back(p_prob[i], static_cast<int>(i));
+      }
     }
-    cum_sum_prob += prob;
-    it->first = cum_sum_prob;
-  }
-  // pick a number based on random in (0, 1)
-  for (auto it = data.begin(); it != data.end(); ++it) {
-    if (uniform_sample < it->first / top_p_sum) {
-      return it->second;
+    if (data.empty()) return -1;
+    auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
+      return lhs.first > rhs.first;
+    };
+    std::sort(data.begin(), data.end(), fcmp);
+
+    // shortcut: if we know that
+    // uniform_sample < p[0] / top_p,
+    // then uniform_sample < p[0] / top_p_sum as well,
+    // because top_p_sum is guaranteed to be at most top_p,
+    // so we can simply return the argmax sample
+    // without computing anything else
+    if (uniform_sample < data[0].first / top_p) return data[0].second;
+
+    // compute top_p_sum
+    float cum_sum_prob = 0.0f;
+    float top_p_sum = 0.0f;
+    for (auto it = data.begin(); it != data.end(); ++it) {
+      float prob = it->first;
+      if (cum_sum_prob < top_p) {
+        top_p_sum += prob;
+      } else {
+        // we have reached the cutoff point
+        break;
+      }
+      cum_sum_prob += prob;
+      it->first = cum_sum_prob;
+    }
+    // the total mass under the given cutoff is not sufficient
+    // to cover the requested top_p, which means we need to
+    // retry with a smaller cutoff
+    if (cum_sum_prob < top_p && cutoff != 0.0f) return -1;
+
+    for (auto it = data.begin(); it != data.end(); ++it) {
+      if (uniform_sample < it->first / top_p_sum) {
+        return it->second;
+      }
     }
+    return data[data.size() - 1].second;
+  };
+
+  if (top_p < 1) {
+    // first try sampling with a probability cutoff; since probabilities
+    // sum to 1, at most 1024 / top_p elements can pass the cutoff of
+    // top_p / 1024, and in practice far fewer do (order of 10 - 20)
+    data.reserve(128);
+    int64_t sampled_index = sample_top_p_with_filter(top_p / 1024);
+    if (sampled_index >= 0) return sampled_index;
   }
-  ICHECK_LE(uniform_sample, data[0].first);
-  return data[0].second;
+  // fallback via full prob, rare case
+  data.reserve(ndata);
+  int64_t sampled_index = sample_top_p_with_filter(0.0f);
+  ICHECK_GE(sampled_index, 0);
+  return sampled_index;
 }
 
 TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);
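
For reference, here is a hedged sketch of invoking the registered
builtin from C++ through the TVM runtime registry (the toy
probabilities, RNG seed, and main() scaffolding are illustrative
assumptions, not part of this commit):

#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

#include <cstdio>
#include <random>

int main() {
  using tvm::runtime::NDArray;
  using tvm::runtime::PackedFunc;
  using tvm::runtime::Registry;

  // Toy probability vector on CPU; all leading dimensions must be 1.
  const int64_t vocab = 4;
  float probs[4] = {0.7f, 0.2f, 0.06f, 0.04f};
  NDArray nd = NDArray::Empty({1, vocab}, DLDataType{kDLFloat, 32, 1},
                              DLDevice{kDLCPU, 0});
  nd.CopyFromBytes(probs, sizeof(probs));

  // Look up the packed function registered above.
  const PackedFunc* f = Registry::Get("vm.builtin.sample_top_p_from_prob");
  if (f == nullptr) return 1;

  // Draw the uniform sample on the host and pass it in explicitly.
  std::mt19937 rng(0);
  double u = std::uniform_real_distribution<double>(0.0, 1.0)(rng);
  int token = (*f)(nd, 0.9, u);
  std::printf("sampled token: %d\n", token);
  return 0;
}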