This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 40fab65be3 [Unity] Update LM Sample builtins (#14793)
40fab65be3 is described below

commit 40fab65be3ec1ce126e94a1c3f8be778a5764feb
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat May 6 19:47:26 2023 -0400

    [Unity] Update LM Sample builtins (#14793)
    
    This PR updates the LM sampling builtins.
    Make sampling from logits aware of temperature.
    Add sampling from probabilities, which samples from a probability
    distribution produced after softmax scaling.
---
 src/runtime/relax_vm/lm_support.cc | 92 +++++++++++++++++++++++++++++++-------
 1 file changed, 77 insertions(+), 15 deletions(-)

diff --git a/src/runtime/relax_vm/lm_support.cc 
b/src/runtime/relax_vm/lm_support.cc
index 8b867cc602..8f7e8ebdf9 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -214,34 +214,41 @@ int SampleTopPFromLogits(NDArray logits, double 
temperature, double top_p, doubl
   for (size_t i = 0; i < data.size(); ++i) {
     data[i] = std::make_pair(plogits[i], static_cast<int>(i));
   }
-  // sort by logits from smallest to largest
-  std::sort(data.begin(), data.end());
-  float max_value = data.back().first;
+
+  auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, 
int>& rhs) {
+    return lhs.first > rhs.first;
+  };
+  // sort by logits from largest to smallest
+  std::sort(data.begin(), data.end(), fcmp);
+
   // argmax
   if (temperature < 1e-6f) {
-    return data.back().second;
+    return data[0].second;
   }
-  // compute expf
-  float sum = 0.0f;
-  for (size_t i = 0; i < data.size(); ++i) {
-    data[i].first = expf(data[i].first - max_value);
-    sum += data[i].first;
+
+  // compute expf scaled by temp
+  float sum = 0.0f, logit_scale = 1.0f / temperature;
+  float max_value = data[0].first;
+  for (auto it = data.begin(); it != data.end(); ++it) {
+    it->first = expf((it->first - max_value) * logit_scale);
+    sum += it->first;
   }
+
   // do a cumsum in order of data
   float cum_sum_prob = 0.0f;
   float top_p_sum = 0.0f;
-  for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
-    float prob = rit->first / sum;
+  for (auto it = data.begin(); it != data.end(); ++it) {
+    float prob = it->first / sum;
     if (cum_sum_prob < top_p) {
       top_p_sum += prob;
     }
     cum_sum_prob += prob;
-    rit->first = cum_sum_prob;
+    it->first = cum_sum_prob;
   }
   // pick a number based on random in (0, 1)
-  for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
-    if (uniform_sample < rit->first / top_p_sum) {
-      return rit->second;
+  for (auto it = data.begin(); it != data.end(); ++it) {
+    if (uniform_sample < it->first / top_p_sum) {
+      return it->second;
     }
   }
   ICHECK_LE(uniform_sample, data[0].first);
@@ -250,6 +257,61 @@ int SampleTopPFromLogits(NDArray logits, double 
temperature, double top_p, doubl
 
 
TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits);
 
+int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) {
+  // Top-p sampling from `prob`, assumed to already hold a normalized
+  // distribution (no softmax is applied here); top_p < 1e-6 returns the argmax.
+  ICHECK(prob.IsContiguous());
+  ICHECK(prob.DataType() == DataType::Float(32));
+  if (prob->device.device_type != kDLCPU) {
+    prob = prob.CopyTo(DLDevice{kDLCPU, 0});
+  }
+  ICHECK(prob->device.device_type == kDLCPU);
+
+  for (int i = 0; i < prob->ndim - 1; ++i) {
+    ICHECK_EQ(prob->shape[i], 1) << "The leading dimensions of prob must be 1";
+  }
+
+  // (probability, token index) pairs, so each index travels with its value.
+  std::vector<std::pair<float, int>> data(prob->shape[prob->ndim - 1]);
+  const float* p_prob = static_cast<float*>(prob->data);
+  for (size_t i = 0; i < data.size(); ++i) {
+    data[i] = std::make_pair(p_prob[i], static_cast<int>(i));
+  }
+
+  auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
+    return lhs.first > rhs.first;
+  };
+  // sort by probability from largest to smallest
+  std::sort(data.begin(), data.end(), fcmp);
+  if (top_p < 1e-6f) {
+    return data.begin()->second;
+  }
+
+  // Overwrite each probability with the running cumsum (in sorted order);
+  // top_p_sum is the mass of the shortest sorted prefix whose cumsum reaches top_p.
+  float cum_sum_prob = 0.0f;
+  float top_p_sum = 0.0f;
+  for (auto it = data.begin(); it != data.end(); ++it) {
+    float p = it->first;
+    if (cum_sum_prob < top_p) {
+      top_p_sum += p;
+    }
+    cum_sum_prob += p;
+    it->first = cum_sum_prob;
+  }
+  // pick the first token whose renormalized cumsum exceeds uniform_sample in (0, 1)
+  for (auto it = data.begin(); it != data.end(); ++it) {
+    if (uniform_sample < it->first / top_p_sum) {
+      return it->second;
+    }
+  }
+  // Fallback when the loop picked nothing: data.back() holds the total mass.
+  ICHECK_LE(uniform_sample, data.back().first);
+  return data.back().second;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm

Reply via email to