On Fri, Nov 17, 2023, at 14:28, Stefan van der Walt wrote:
> Attached is a script that implements this solution.
And here is the version that checks for duplicates using a set.
Stéfan
import random
import functools
import itertools
import operator
import numpy as np
def cumulative_prod(arr):
    """Return the running products of ``arr`` as a list.

    For example ``cumulative_prod([2, 3, 4])`` is ``[2, 6, 24]``; an empty
    input yields an empty list.
    """
    products = itertools.accumulate(arr, operator.mul)
    return list(products)
def unravel_index(x, dims):
    """Convert flat row-major (C-order) indices into per-axis coordinates.

    Parameters
    ----------
    x : iterable of int
        Flat indices into an array of shape ``dims``.
    dims : sequence of int
        The array shape.

    Returns
    -------
    list of list of int
        One coordinate list per element of ``x``, in the order ``x`` yields them.

    NOTE: every coordinate is reduced modulo its axis length, so a flat index
    outside ``[0, prod(dims))`` wraps around instead of raising.
    """
    # Row-major strides: the last axis has stride 1, and each earlier axis has
    # the product of all axis lengths that follow it.
    strides = cumulative_prod([1] + list(dims)[:0:-1])[::-1]
    coords = []
    for flat in x:
        coords.append(
            [(flat // stride) % extent for stride, extent in zip(strides, dims)]
        )
    return coords
# From Robert Kern's comment at
# https://github.com/numpy/numpy/issues/24458#issuecomment-1685022258
class PythonRandomInterface(random.Random):
    """Adapt a NumPy random generator to the ``random.Random`` API.

    Only ``getrandbits`` is overridden. The ``random.Random`` helpers that
    build on it (``randrange``, ``randint``, ...) therefore draw their bits
    from the wrapped generator.

    NOTE(review): ``random()`` and the float-based methods are not overridden
    here. They presumably still use the inherited Mersenne Twister state
    rather than ``rng``; confirm before relying on them.
    """

    def __init__(self, rng):
        # ``rng`` must provide ``bytes(n)``; the demo passes
        # ``np.random.default_rng()``. ``random.Random.__init__`` is not called.
        self._rng = rng

    def getrandbits(self, k):
        """getrandbits(k) -> x. Generates an int with k random bits."""
        if k < 0:
            raise ValueError('number of bits must be non-negative')
        numbytes = (k + 7) // 8  # bits / 8 and rounded up
        x = int.from_bytes(self._rng.bytes(numbytes), 'big')
        return x >> (numbytes * 8 - k)  # trim excess bits

    def indices(self, shape, size=1):
        """Draw ``size`` distinct multi-indices uniformly from an array of ``shape``.

        Flat positions are drawn without replacement: a repeated draw is
        rejected because it is already in a set. The positions are then
        converted to coordinates with the module-level ``unravel_index``.

        Parameters
        ----------
        shape : sequence of int
            Shape of the array to sample positions from.
        size : int, default 1
            Number of distinct positions to draw.

        Returns
        -------
        list of list of int
            ``size`` coordinate lists. They come in the iteration order of an
            internal set, not in draw order.

        Raises
        ------
        ValueError
            If ``size`` exceeds the number of elements in ``shape``; the
            rejection loop could otherwise never terminate.
        """
        # Initial value 1 makes an empty shape (a scalar) have one element
        # instead of crashing ``reduce``.
        total = functools.reduce(operator.mul, shape, 1)
        if size > total:
            raise ValueError(
                f"cannot draw {size} distinct indices from {total} elements"
            )
        flat = set()
        while len(flat) < size:
            # randrange excludes ``total``. randint(0, total) would not, and
            # unravel_index would silently wrap that value onto index 0.
            flat.add(self.randrange(total))
        return unravel_index(flat, shape)
# Demo: draw k distinct multi-indices from a large 9-D shape without
# materializing an array of that shape.
dims = (500, 400, 30, 15, 20, 800, 900, 2000, 800)
k = 5
rng = np.random.default_rng()
pri = PythonRandomInterface(rng)
print(pri.indices(dims, size=k))
_______________________________________________
NumPy-Discussion mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/numpy-discussion.python.org/
Member address: [email protected]