i've started vectorizing the operations in my for loop that calculates fetches 
into pages around sparse file holes. it can be done by rote, despite the control 
flow and such. interesting that vector operations and loops can be transformed 
into each other.
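
a minimal sketch of that loop/vector equivalence, assuming the branch in question is a clamp to the file size. the function names and the sample numbers here are made up for illustration and are not from the real read_many:

```python
import torch

def clamp_tails_loop(offsets, lengths, file_size):
    # scalar version: one request at a time, with an explicit branch
    tails = []
    for off, length in zip(offsets.tolist(), lengths.tolist()):
        tail = off + length
        if tail > file_size:
            tail = file_size
        tails.append(tail)
    return torch.tensor(tails, dtype=torch.int64)

def clamp_tails_vectorized(offsets, lengths, file_size):
    # same logic: the branch becomes a boolean mask handed to torch.where
    tails = offsets + lengths
    return torch.where(tails > file_size, torch.full_like(tails, file_size), tails)

offsets = torch.tensor([0, 10, 95])
lengths = torch.tensor([4, 4, 10])
loop_result = clamp_tails_loop(offsets, lengths, 100)
vector_result = clamp_tails_vectorized(offsets, lengths, 100)
```

the rote part is that each `if` turns into a mask and each assignment under it turns into a `torch.where` (or a masked in-place write) over the whole tensor.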

it seems workable: if i can vectorize this big long function correctly, it 
would condense the millions of scalar requests into larger page requests that 
surround groups of them.
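
a rough sketch of the condensing step, assuming block-aligned ranges and non-negative int64 offsets. `pages_for_requests`, the block size and the sample offsets are hypothetical, chosen only to show many scalar reads collapsing into a few aligned ranges:

```python
import torch

def pages_for_requests(offsets, lengths, blksize, file_size):
    """Return the unique block-aligned (start, end) rows covering each request."""
    tails = (offsets + lengths).clamp(max=file_size)
    starts = offsets // blksize * blksize            # round starts down to a block
    ends = ((tails - 1) // blksize + 1) * blksize    # round ends up to a block
    ends = ends.clamp(max=file_size)
    pages = torch.stack([starts, ends], dim=1)
    # identical rows collapse to one; rows come back sorted
    return torch.unique(pages, dim=0)

offsets = torch.tensor([0, 4, 8, 5000, 5004])
lengths = torch.full_like(offsets, 4)
pages = pages_for_requests(offsets, lengths, blksize=4096, file_size=10000)
```

here five 4-byte reads turn into two page-sized ranges. with millions of scalar reads clustered in the same blocks, the reduction would be correspondingly larger.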

one thing that's missing is consolidating adjacent pages, but this is also 
reasonable to vectorize
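
one possible way to vectorize that consolidation, as a sketch only: it assumes at least one range, that every end is >= its start, and that touching ranges (one ends exactly where the next starts) should be merged. `merge_ranges` is a hypothetical helper, not part of the code below:

```python
import torch

def merge_ranges(starts, ends):
    """Merge overlapping or touching [start, end) ranges without a Python loop."""
    order = torch.argsort(starts)
    starts, ends = starts[order], ends[order]
    # furthest end seen so far, scanning left to right
    reach = torch.cummax(ends, dim=0).values
    # a new group begins where a start lies beyond everything before it
    new_group = torch.ones_like(starts, dtype=torch.bool)
    new_group[1:] = starts[1:] > reach[:-1]
    # the last element of a group is the one just before the next group start
    last = torch.ones_like(new_group)
    last[:-1] = new_group[1:]
    return starts[new_group], reach[last]

starts = torch.tensor([16384, 0, 4096, 40960, 20480])
ends = torch.tensor([20480, 4096, 8192, 45056, 24576])
merged_starts, merged_ends = merge_ranges(starts, ends)
```

the running maximum stands in for the "current merged range" that a loop would carry along, so the loop state becomes a single cumulative op.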

            tails = (offset_lengths[:,0] + offset_lengths[:,1]).clamp(max=len(self.mmap))
            aligned_offsets = offset_lengths[:,0] // self.blksize
            aligned_offsets *= self.blksize
            aligned_tails = (tails - 1)
            aligned_tails //= self.blksize
            aligned_tails += 1
            aligned_tails *= self.blksize
            torch.clamp(aligned_tails, max=self.size(), out=aligned_tails)

that's the start of my vectorized read_many, which i've only just begun. it's 
called by this code:

    def fetch_scalars(self, offsets, progress='', validate_usage=True):
        if validate_usage:
            bytes_avail_cpu = psutil.virtual_memory().available * self.safeslice.statedict.usage_frac
            assert offsets.numel() * self.dtype.itemsize < bytes_avail_cpu
        readsize = self.safeslice.tensor.element_size()
        offset_lengths = torch.empty([offsets.numel(), 2], dtype=int)
        offset_lengths[:,0] = offsets.view(-1)
        offset_lengths[:,0] += self.storage_offset()
        offset_lengths[:,0] *= readsize
        offset_lengths[:,0] += self.safeslice.offset
        offset_lengths[:,1] = readsize
        datas = b''.join(self.safeslice.fetcher.read_many(offset_lengths, progress=progress, validate_sorted=False))
        return torch.frombuffer(
            datas,
            dtype=self.safeslice.tensor.dtype,
            count=offsets.numel(),
        ).view(offsets.shape).to(self.device, self.dtype)

which is called by something like this code in F_linear:

        major_stride, minor_stride = weight.stride()
        offsets = torch.add(*torch.meshgrid(
            torch.arange(weight.shape[0]) * major_stride,
            top_k_indices * minor_stride,
            indexing = 'ij'
        ))
... [below simplified removing local unfixed bugs to try to match new code here]
            product = torch.matmul(
                input,
                weight.fetch_scalars(offsets, progress=name, validate_usage=False).T
            )
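
a small check of what the meshgrid offsets select, using an in-memory tensor as a stand-in for the on-disk weight. the 3x4 shape and the `top_k_indices` values are made up, and plain indexing of the flattened tensor stands in for fetch_scalars:

```python
import torch

weight = torch.arange(12.0).reshape(3, 4)   # stand-in for the on-disk weight
top_k_indices = torch.tensor([3, 1])        # hypothetical selected columns

major_stride, minor_stride = weight.stride()
offsets = torch.add(*torch.meshgrid(
    torch.arange(weight.shape[0]) * major_stride,
    top_k_indices * minor_stride,
    indexing='ij',
))

# element offsets into the flat storage pick out the chosen columns of every row
flat = weight.reshape(-1)
gathered = flat[offsets]
```

so each entry of `offsets` is the flat element index of `weight[row, col]`, and the fetch is equivalent to `weight[:, top_k_indices]` without ever materializing the other columns.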
