zhuzilin opened a new pull request #8041:
URL: https://github.com/apache/tvm/pull/8041


   This PR adds a uniform distribution generator based on the threefry PRNG introduced in #7083. We need `uniform` to build training-phase dropout, following this roadmap:
   
   ```
   uniform -> bernoulli -> dropout
   ```
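As a rough illustration of the roadmap, here is how each stage could be built from the previous one. This is NumPy rather than TVM code. The `bernoulli` and `dropout` helpers are hypothetical, and the inverted-dropout rescaling by `1 / (1 - rate)` is an assumed, commonly used convention.

```python
import numpy as np

def bernoulli(u, p):
    """Hypothetical helper: turn uniform samples u in [0, 1) into Bernoulli(p) draws.

    P(u < p) = p when u is uniform on [0, 1), so a comparison is all that is needed.
    """
    return u < p

def dropout(x, u, rate):
    """Hypothetical inverted dropout built on bernoulli().

    Each element is kept with probability 1 - rate. Kept elements are scaled by
    1 / (1 - rate) so that the expected value of the output matches x.
    """
    keep = bernoulli(u, 1.0 - rate)
    return np.where(keep, x / (1.0 - rate), 0.0)

# NumPy's rng.random stands in here for the uniform generator added in this PR.
rng = np.random.default_rng(0)
x = np.ones((2, 4), dtype=np.float32)
y = dropout(x, rng.random(x.shape), rate=0.5)
```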
   
   The algorithm is essentially the same as the one used in JAX: the random bits generated by `threefry_generate` become the fraction (mantissa) bits of a float32 or float64. Specifically, I keep 23 random bits for float32 and 52 for float64, taking the most significant bits with a right shift. There is one difference from the JAX implementation: JAX uses a bitcast to turn the unsigned integers into floats:
   
   ```python
   # jax implementation
   def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
     ...
     bits = _random_bits(key, nbits, shape)
   
     # The strategy here is to randomize only the mantissa bits with an exponent of
     # 1 (after applying the bias), then shift and scale to the desired range. The
     # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
     # equivalent float representations, which might not be true on all platforms.
     float_bits = lax.bitwise_or(
         lax.shift_right_logical(bits, np.array(nbits - nmant, lax.dtype(bits))),
         np.array(1., dtype).view(_UINT_DTYPES[nbits]))
     floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
     return lax.max(
         minval,
         lax.reshape(floats * (maxval - minval) + minval, shape.positional))
   ```
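For readers unfamiliar with the trick, the NumPy sketch below mirrors the JAX code outside of XLA. It keeps the top 23 (float32) or 52 (float64) bits of each random word as the mantissa `m`, ORs in the bit pattern of `1.0` so that the value is `1 + m * 2**-nmant` in [1, 2), reinterprets the bits with `ndarray.view`, subtracts 1, and rescales to `[minval, maxval)`. It assumes the random bits arrive as a `uint64` array (as threefry output would). `uniform_bitcast` is a hypothetical name used only for illustration.

```python
import numpy as np

# Mantissa (fraction) widths of IEEE-754 binary32 / binary64.
NMANT = {np.float32: 23, np.float64: 52}
# Unsigned integer type with the same width as each float type.
UINT = {np.float32: np.uint32, np.float64: np.uint64}

def uniform_bitcast(bits, dtype=np.float32, minval=0.0, maxval=1.0):
    """Map uint64 random bits to floats in [minval, maxval) via a bitcast.

    `bits` is assumed to be a uint64 array; `dtype` is np.float32 or np.float64.
    """
    nmant = NMANT[dtype]
    # Top `nmant` bits of each 64-bit word, narrowed to the float's width.
    mant = (bits >> np.uint64(64 - nmant)).astype(UINT[dtype])
    # Bit pattern of 1.0: sign 0, biased exponent for 2**0, mantissa 0.
    one_bits = np.array(1.0, dtype=dtype).view(UINT[dtype])
    # OR in the exponent, reinterpret as float (value in [1, 2)), shift to [0, 1).
    floats = (mant | one_bits).view(dtype) - dtype(1.0)
    # Like lax.max in the JAX code: keep rounding from going below minval.
    return np.maximum(dtype(minval), floats * dtype(maxval - minval) + dtype(minval))
```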
   
   However, since I couldn't find a bitcast in TE or TOPI, I convert the integers to floats and divide by `2**nfraction` instead, which may be slower:
   
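   ```python
   def uniform_scalar(bits):
       # nbits, nfraction and out_dtype are defined in the enclosing function;
       # nfraction is 23 for float32 and 52 for float64.
       # Keep the top nfraction bits: an integer in [0, 2**nfraction).
       bits = bits >> (nbits - nfraction)
       # Cast to float and scale by 2**-nfraction to land in [0, 1).
       standard_uniform = bits.astype(out_dtype) / float(1 << nfraction)
       return standard_uniform
   ```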
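In exact IEEE arithmetic, the divide-based conversion yields the same values as the bitcast. An integer below `2**nfraction` converts to float32/float64 without rounding, and dividing by a power of two only changes the exponent, so the choice should affect speed rather than results. Below is a NumPy sketch of the same computation, written as a hypothetical stand-alone port of `uniform_scalar`. Here `nbits`, `nfraction` and `out_dtype` become parameters, and `uint64` input bits are assumed.

```python
import numpy as np

def uniform_divide(bits, nfraction=23, out_dtype=np.float32, nbits=64):
    """Hypothetical NumPy port of `uniform_scalar`: random bits -> floats in [0, 1).

    Assumes `bits` holds unsigned `nbits`-wide random integers
    (uint64 by default, e.g. threefry output).
    """
    # Keep the top `nfraction` bits: integers in [0, 2**nfraction).
    mant = bits >> np.uint64(nbits - nfraction)
    # The cast is exact (mant fits in the mantissa), and dividing by a
    # power of two is exact as well.
    return mant.astype(out_dtype) / out_dtype(1 << nfraction)
```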
   
   Thank you for taking the time to review this PR. I'm not yet very familiar with the TVM codebase, so please let me know if I've broken any community conventions and I'll gladly fix them :).
   
   Gentle ping @tqchen @altanh @tkonolige 

