On Fri, Nov 17, 2023, at 14:28, Stefan van der Walt wrote:
> Attached is a script that implements this solution.

And here is the version that uses a set to check for duplicates.

Stéfan
import random
import functools
import itertools
import operator

import numpy as np


def cumulative_prod(arr):
    """Return the running products of ``arr`` as a list.

    Element ``i`` of the result is ``arr[0] * arr[1] * ... * arr[i]``.
    An empty input yields an empty list.
    """
    running = itertools.accumulate(arr, operator.mul)
    return [*running]


def unravel_index(x, dims):
    """Convert flat indices into per-axis coordinates for shape ``dims``.

    Parameters
    ----------
    x
        Iterable of integer flat indices.
    dims
        Sequence of axis lengths; it is indexed and passed to ``len``.

    Returns
    -------
    list of list
        One coordinate list per element of ``x``, in iteration order of
        ``x``, each with ``len(dims)`` entries.
    """
    axes = range(len(dims))
    # strides[i] is the product of dims[i + 1:], so the last axis has
    # stride 1 (row-major layout).
    strides = cumulative_prod([1] + list(dims)[:0:-1])[::-1]
    coords = []
    for flat in x:
        coords.append([(flat // strides[i]) % dims[i] for i in axes])
    return coords


# From Robert Kern's comment at
# https://github.com/numpy/numpy/issues/24458#issuecomment-1685022258
class PythonRandomInterface(random.Random):
    """Drive the ``random.Random`` integer API from an external bit source.

    Only ``getrandbits`` is overridden here.  ``random.Random`` builds
    ``randrange``/``randint`` on top of ``getrandbits`` when a subclass
    supplies it, so those methods draw from ``rng`` and work with
    arbitrarily large Python integers.

    Parameters
    ----------
    rng
        Object exposing ``bytes(n)`` that returns ``n`` random bytes.
    """

    def __init__(self, rng):
        # Run the base initialiser so inherited state is set up.
        super().__init__()
        self._rng = rng

    def getrandbits(self, k):
        """getrandbits(k) -> x.  Generates an int with k random bits."""
        if k < 0:
            raise ValueError('number of bits must be non-negative')
        numbytes = (k + 7) // 8                       # bits / 8 and rounded up
        x = int.from_bytes(self._rng.bytes(numbytes), 'big')
        return x >> (numbytes * 8 - k)                # trim excess bits

    def indices(self, shape, size=1):
        """Draw ``size`` distinct multi-dimensional indices into ``shape``.

        Flat indices are drawn uniformly from ``range(prod(shape))`` and
        collected in a set, so duplicates are discarded and redrawn until
        ``size`` unique values have been found.

        Parameters
        ----------
        shape
            Sequence of axis lengths.
        size
            Number of distinct indices to return.

        Returns
        -------
        list of list
            ``size`` coordinate lists, one entry per axis of ``shape``.
            The order is the iteration order of the internal set and is
            therefore unspecified.

        Raises
        ------
        ValueError
            If ``size`` is larger than the number of elements described by
            ``shape``; the sampling loop could never finish in that case.
        """
        # Initial value 1 makes an empty shape describe a single element.
        total = functools.reduce(operator.mul, shape, 1)
        if size > total:
            raise ValueError(
                f'cannot draw {size} unique indices from only {total} elements'
            )
        flat = set()
        while len(flat) < size:
            # randrange excludes the upper bound; randint(0, total) would
            # include ``total`` itself, which is one past the last element.
            flat.add(self.randrange(total))
        return unravel_index(flat, shape)


# Demo: sample k distinct indices from a 9-dimensional shape.
dims = (500, 400, 30, 15, 20, 800, 900, 2000, 800)
k = 5

# Unseeded NumPy generator wrapped in the random.Random-style adapter, so
# the printed indices differ from run to run.
rng = np.random.default_rng()
pri = PythonRandomInterface(rng)

print(pri.indices(shape=dims, size=k))
_______________________________________________
NumPy-Discussion mailing list -- numpy-discussion@python.org
To unsubscribe send an email to numpy-discussion-le...@python.org
https://mail.python.org/mailman3/lists/numpy-discussion.python.org/
Member address: arch...@mail-archive.com

Reply via email to