zhuzilin opened a new pull request #8041:
URL: https://github.com/apache/tvm/pull/8041
This PR adds a uniform distribution generator built on the threefry PRNG
introduced in #7083. We need `uniform` to implement training-phase dropout,
following this roadmap:
```
uniform -> bernoulli -> dropout
```
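The roadmap can be illustrated with a small pure-Python sketch. Some details here are assumptions rather than part of this PR:

- The bit source is Python's `random.Random`, a Mersenne Twister, standing in for threefry.
- The names `uniform`, `bernoulli` and `dropout` are hypothetical. They are not TVM APIs.
- Bernoulli(p) comes from comparing a U[0, 1) draw against p.
- Dropout is the common "inverted" variant, which rescales surviving elements by 1 / (1 - rate). Whether the eventual TVM operator scales the same way is also an assumption.

```python
import random

# Hypothetical helpers sketching uniform -> bernoulli -> dropout.
# Python's Mersenne Twister stands in for threefry here.

def uniform(rng, n, low=0.0, high=1.0):
    """Draw n samples from U[low, high)."""
    return [low + rng.random() * (high - low) for _ in range(n)]

def bernoulli(rng, n, p):
    """Each sample is 1 with probability p: compare a U[0, 1) draw with p."""
    return [1 if u < p else 0 for u in uniform(rng, n)]

def dropout(rng, xs, rate):
    """Inverted dropout: zero each element with probability `rate` and scale
    the survivors by 1 / (1 - rate) so the expected value is unchanged."""
    keep = bernoulli(rng, len(xs), 1.0 - rate)
    return [x * k / (1.0 - rate) for x, k in zip(xs, keep)]

rng = random.Random(0)
print(dropout(rng, [1.0, 2.0, 3.0, 4.0], rate=0.5))
```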
The algorithm is essentially the same as the one used in jax. The random
bits generated by `threefry_generate` become the fraction (mantissa) bits of
a float32 or float64. Specifically, I use the 23 most significant random bits
for float32 and the 52 most significant for float64. There is one difference
from the jax implementation: jax uses a bitcast to turn the uint into a float:
```python
# jax implementation
def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
    ...
    bits = _random_bits(key, nbits, shape)

    # The strategy here is to randomize only the mantissa bits with an exponent of
    # 1 (after applying the bias), then shift and scale to the desired range. The
    # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
    # equivalent float representations, which might not be true on all platforms.
    float_bits = lax.bitwise_or(
        lax.shift_right_logical(bits, np.array(nbits - nmant, lax.dtype(bits))),
        np.array(1., dtype).view(_UINT_DTYPES[nbits]))
    floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
    return lax.max(
        minval,
        lax.reshape(floats * (maxval - minval) + minval, shape.positional))
```
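To make the bitcast trick concrete, here is a pure-Python sketch for float32. It uses `struct` to reinterpret bit patterns and assumes 32-bit random words.

- ORing the shifted random bits into the bit pattern of 1.0 (0x3F800000: sign 0, biased exponent 127) yields a float in [1, 2).
- Subtracting 1.0 then gives [0, 1).

The function name `bitcast_uniform_f32` is hypothetical. Python floats are float64, so the arithmetic after the reinterpretation runs in double precision rather than float32.

```python
import random
import struct

NBITS = 32  # width of each random word (assumed uint32 here)
NMANT = 23  # float32 fraction (mantissa) bits
ONE_BITS = struct.unpack("<I", struct.pack("<f", 1.0))[0]  # 0x3F800000

def bitcast_uniform_f32(bits, minval=0.0, maxval=1.0):
    """Hypothetical helper: map a 32-bit word to [minval, maxval)."""
    # Keep the NMANT most significant random bits as the mantissa.
    mant = bits >> (NBITS - NMANT)
    # OR in the bit pattern of 1.0 (exponent of 1 after bias): value in [1, 2).
    float_bits = mant | ONE_BITS
    one_to_two = struct.unpack("<f", struct.pack("<I", float_bits))[0]
    # Shift to [0, 1), scale to [minval, maxval), and clamp like lax.max.
    floats = one_to_two - 1.0
    return max(minval, floats * (maxval - minval) + minval)

rng = random.Random(0)
print([bitcast_uniform_f32(rng.getrandbits(32)) for _ in range(3)])
```

For example, `bitcast_uniform_f32(1 << 31)` sets only the top fraction bit, which gives 1.5 - 1.0 = 0.5.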
However, since I couldn't find a bitcast operation in `te` or `topi`, I
instead cast the integer to float and divide, which may be slower:
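```python
def uniform_scalar(bits):
    # nbits, nfraction and out_dtype are defined outside this snippet.
    # Keep the nfraction most significant random bits as the fraction.
    bits = bits >> (nbits - nfraction)
    # Cast to float and scale by 2**-nfraction so the result lies in [0, 1).
    standard_uniform = bits.astype(out_dtype) / float(1 << nfraction)
    return standard_uniform
```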
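The divide-based path can be checked against the bitcast in the same pure-Python setting. The names `divide_uniform`, `bitcast_uniform_f32` and `NFRACTION` are hypothetical stand-ins for the PR's `uniform_scalar`, `nfraction` and `out_dtype`.

The two approaches should produce identical values:

- The kept integer `m` is smaller than 2**nfraction, so it converts to float without rounding.
- Dividing by a power of two only adjusts the exponent.

Any difference between the approaches would therefore be in speed, not in results.

```python
import struct

# Hypothetical lookup of fraction (mantissa) widths per output dtype.
NFRACTION = {"float32": 23, "float64": 52}

def divide_uniform(bits, nbits, out_dtype):
    """Standard uniform in [0, 1) from an nbits-wide random word, using a
    shift plus int-to-float division instead of a bitcast."""
    nfraction = NFRACTION[out_dtype]
    # Keep the nfraction most significant random bits.
    m = bits >> (nbits - nfraction)
    # m < 2**nfraction converts exactly, and dividing by a power of two is
    # exact, so no rounding happens here.
    return float(m) / float(1 << nfraction)

def bitcast_uniform_f32(bits):
    """Reference: jax-style exponent-1 bitcast for a 32-bit word."""
    m = bits >> (32 - 23)
    return struct.unpack("<f", struct.pack("<I", m | 0x3F800000))[0] - 1.0

# Both paths agree exactly on a few float32 sample words.
for word in (0, 1 << 31, 0xDEADBEEF, 0xFFFFFFFF):
    assert divide_uniform(word, 32, "float32") == bitcast_uniform_f32(word)
```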
Thank you for taking the time to review this PR. I'm not yet very familiar
with the TVM codebase, so I apologize if I've broken any community
conventions, and I'd be glad to fix them :).
Gentle ping @tqchen @altanh @tkonolige
--
This is an automated message from the Apache Git Service.