This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e6fad30 Efficient MXNet sampling in the multinomial distribution (#15311)
e6fad30 is described below
commit e6fad30e45e6ec0ddef5c18093e8163cd2a7c62c
Author: Zixuan Wei <[email protected]>
AuthorDate: Sat Jun 22 14:26:27 2019 +0800
Efficient MXNet sampling in the multinomial distribution (#15311)
* Effective multinomial
* Meaningful uniform data pointer as input
* Remove leading zeros from CDFs
* Double precision for accumulated var
---
src/operator/random/sample_multinomial_op.h | 42 ++++++++++++++++-------------
1 file changed, 24 insertions(+), 18 deletions(-)
diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h
index 377df4f..5a0b9bb 100644
--- a/src/operator/random/sample_multinomial_op.h
+++ b/src/operator/random/sample_multinomial_op.h
@@ -122,25 +122,29 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
struct SampleMultinomialKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, index_t K, index_t M,
- DType* dist, float* uniform, IType* out,
- DType* prob) {
+ DType* dist, float* uniform, float* cum_table,
+ IType* out, DType* prob) {
+ double acc = 0.0;
+ // Row CDF accumulated in double: cum_table[i*K + c] = dist[i*K] + ... + dist[i*K + c].
+ for (index_t c = 0; c < K; ++c) {
+ acc += dist[i*K + c];
+ cum_table[i*K + c] = static_cast<float>(acc);
+ }
for (index_t j = 0; j < M; ++j) {
+ // Upper bound over [0, K-1): ending on K-1 reproduces the old !found fallback.
+ index_t left = 0, right = K - 1;
DType loc = static_cast<DType>(uniform[i*M + j]);
- DType acc = 0;
- bool found = false;
- for (index_t k = 0; k < K; ++k) {
- acc += dist[i*K + k];
- if (acc > loc) {
- found = true;
- out[i*M + j] = static_cast<IType>(k);
- if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + k]);
- break;
+ while (left < right) {
+ const index_t middle = left + (right - left) / 2;
+ DType cum_prob = cum_table[i*K + middle];
+ if (cum_prob <= loc) {  // first CDF > loc, as in the old `acc > loc` scan
+ left = middle + 1;
+ } else {
+ right = middle;
}
}
- if (!found) {
- out[i*M + j] = static_cast<IType>(K-1);
- if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + K - 1]);
- }
+ out[i*M + j] = static_cast<IType>(left);
+ if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + left]);
}
}
};
@@ -163,12 +167,14 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
- Tensor<xpu, 1, float> uniform =
- ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M), s);
+ Tensor<xpu, 1, float> workspace =
+ ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M + N*K), s);
+ Tensor<xpu, 1, float> uniform(workspace.dptr_, Shape1(N*M));
prnd->SampleUniform(&uniform, 0, 1);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
Kernel<SampleMultinomialKernel, xpu>::Launch(
- s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_,
outputs[0].dptr<IType>(),
+ s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, workspace.dptr_ +
N*M,
+ outputs[0].dptr<IType>(),
param.get_prob ? outputs[1].dptr<DType>() : nullptr);
});
});