This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 90edf76716 [Unity][LLM] Add NaN checks during sampling for better
error reporting (#16141)
90edf76716 is described below
commit 90edf767167224a5367ff24eba894fe3d5e502a4
Author: Siyuan Feng <[email protected]>
AuthorDate: Sat Nov 18 01:36:28 2023 +0800
[Unity][LLM] Add NaN checks during sampling for better error reporting
(#16141)
The current error message would be confusing:
```
mlc-llm/3rdparty/tvm/src/runtime/relax_vm/lm_support.cc:421: InternalError:
Check failed: sampled_index >= 0 (-1 vs. 0)
```
But in most cases the failure is caused by a NaN error. This PR improves the
error message.
---
src/runtime/relax_vm/lm_support.cc | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/src/runtime/relax_vm/lm_support.cc
b/src/runtime/relax_vm/lm_support.cc
index fbff8ff029..6301245dac 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -452,6 +452,10 @@ int SampleTopPFromProb(NDArray prob, double top_p, double
uniform_sample) {
return data[data.size() - 1].second;
};
+ auto is_all_nan = [&]() -> bool {
+ return std::all_of(p_prob, p_prob + ndata, [](float x) { return
std::isnan(x); });
+ };
+
if (top_p < 1) {
// sample through cutoff by a number
// by pigeonhole principle we will get at most 1024 elements
@@ -463,7 +467,11 @@ int SampleTopPFromProb(NDArray prob, double top_p, double
uniform_sample) {
// fallback via full prob, rare case
data.reserve(ndata);
int64_t sampled_index = sample_top_p_with_filter(0.0f);
- ICHECK_GE(sampled_index, 0);
+ if (sampled_index < 0 && is_all_nan()) {
+ LOG(FATAL) << "The output probabilities are all NaNs, can not sample from
it";
+ } else if (sampled_index < 0) {
+ LOG(FATAL) << "Cannot sample from the given probability distribution due
to unknown reason";
+ }
return sampled_index;
}