i've started vectorizing the operations in my for loop that calculates fetches
into pages around sparse file holes. it can be done by rote, despite the control
flow and such. interesting that vector operations and loops can be transformed
into each other.
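a minimal sketch of that loop/vector equivalence, assuming made-up names (`tails_loop`, `tails_vec`) and toy numbers that aren't from the real code. the branch inside the loop turns into a mask handed to `torch.where`:

```python
import torch

def tails_loop(offsets, lengths, file_size):
    # scalar version: one branch per request
    out = []
    for off, ln in zip(offsets.tolist(), lengths.tolist()):
        tail = off + ln
        if tail > file_size:  # control flow inside the loop
            tail = file_size
        out.append(tail)
    return torch.tensor(out, dtype=torch.int64)

def tails_vec(offsets, lengths, file_size):
    # vector version: the `if` becomes a boolean mask
    tails = offsets + lengths
    return torch.where(tails > file_size, torch.full_like(tails, file_size), tails)

offsets = torch.tensor([0, 10, 4090, 8190])
lengths = torch.tensor([4, 4, 4, 4])
assert torch.equal(tails_loop(offsets, lengths, 8192),
                   tails_vec(offsets, lengths, 8192))
```

for this particular branch `.clamp(max=file_size)` does the same thing in one call; `torch.where` is just the general-purpose form of "if/else per element".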
it seems workable: if i can vectorize this big long function correctly, it
would condense the millions of scalar requests into larger page requests that
each surround a group of them.
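roughly what "condensing" could look like, as a sketch only. `pages_for_requests` is a hypothetical helper, and it assumes no scalar straddles a page boundary:

```python
import torch

def pages_for_requests(byte_offsets, blksize):
    # index of the page that holds each scalar
    page_idx = byte_offsets // blksize
    # distinct pages, plus for every request the position of its page in `pages`
    pages, inverse = torch.unique(page_idx, return_inverse=True)
    return pages * blksize, inverse

byte_offsets = torch.tensor([0, 8, 4096, 4104, 20480])
page_starts, which_page = pages_for_requests(byte_offsets, blksize=4096)
# five scalar requests collapse into three page reads
assert page_starts.tolist() == [0, 4096, 20480]
```

`which_page` is what would let each scalar be sliced back out of the page it landed in after the reads are done.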
one thing that's missing is consolidating adjacent pages, but that is also
reasonable to vectorize.
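one possible way to vectorize the consolidation, sketched with a hypothetical `merge_adjacent_pages` helper (not part of the real code). it assumes the page indices are already sorted, unique, and non-empty:

```python
import torch

def merge_adjacent_pages(page_idx):
    # gap between each page and the one before it
    gaps = page_idx[1:] - page_idx[:-1]
    # a run starts at the first page and after every gap larger than 1
    is_start = torch.cat([torch.tensor([True]), gaps > 1])
    # a run ends before every such gap and at the last page
    is_end = torch.cat([gaps > 1, torch.tensor([True])])
    first = page_idx[is_start]
    last = page_idx[is_end]
    # half-open page ranges [first, last + 1)
    return torch.stack([first, last + 1], dim=1)

runs = merge_adjacent_pages(torch.tensor([0, 1, 2, 5, 9, 10]))
assert runs.tolist() == [[0, 3], [5, 6], [9, 11]]
```

multiplying the result by the block size would give byte ranges, one read per run instead of one per page.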
tails = (offset_lengths[:,0] + offset_lengths[:,1]).clamp(max=len(self.mmap))
aligned_offsets = offset_lengths[:,0] // self.blksize
aligned_offsets *= self.blksize
aligned_tails = tails - 1
aligned_tails //= self.blksize
aligned_tails += 1
aligned_tails *= self.blksize
torch.clamp(aligned_tails, max=self.size(), out=aligned_tails)
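the same arithmetic as a standalone sketch with worked numbers. `align_to_blocks` and the `blksize` / `file_size` arguments are stand-ins for the `self.*` attributes above, so treat the names as assumptions:

```python
import torch

def align_to_blocks(offset_lengths, blksize, file_size):
    # offset_lengths: [N, 2] int64 rows of (byte offset, byte length)
    tails = (offset_lengths[:, 0] + offset_lengths[:, 1]).clamp(max=file_size)
    # round starts down to a block boundary
    aligned_offsets = (offset_lengths[:, 0] // blksize) * blksize
    # round tails up to a block boundary: ((tail - 1) // blksize + 1) * blksize
    aligned_tails = ((tails - 1) // blksize + 1) * blksize
    # the last block may be cut short by the end of the file
    aligned_tails = aligned_tails.clamp(max=file_size)
    return aligned_offsets, aligned_tails

offset_lengths = torch.tensor([[0, 4], [4094, 4], [9000, 4]])
starts, ends = align_to_blocks(offset_lengths, blksize=4096, file_size=10000)
assert starts.tolist() == [0, 0, 8192]
assert ends.tolist() == [4096, 8192, 10000]
```

the second request (bytes 4094 to 4098) crosses a block boundary, so its aligned range covers two blocks; the third is clipped at the end of the 10000-byte file.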
that's the start of my vectorized read_many, which i've only just begun. it's
called by this code:
def fetch_scalars(self, offsets, progress='', validate_usage=True):
    if validate_usage:
        bytes_avail_cpu = (psutil.virtual_memory().available *
                           self.safeslice.statedict.usage_frac)
        # numel() rather than len(): offsets may be multi-dimensional
        assert offsets.numel() * self.dtype.itemsize < bytes_avail_cpu
    readsize = self.safeslice.tensor.element_size()
    # one (byte offset, byte length) row per requested scalar
    offset_lengths = torch.empty([offsets.numel(), 2], dtype=int)
    offset_lengths[:,0] = offsets.view(-1)
    offset_lengths[:,0] += self.storage_offset()
    offset_lengths[:,0] *= readsize
    offset_lengths[:,0] += self.safeslice.offset
    offset_lengths[:,1] = readsize
    datas = b''.join(self.safeslice.fetcher.read_many(
        offset_lengths, progress=progress, validate_sorted=False))
    return torch.frombuffer(
        datas,
        dtype=self.safeslice.tensor.dtype,
        count=offsets.numel(),  # was len(offsets), which undercounts for 2-D offsets
    ).view(offsets.shape).to(self.device, self.dtype)
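a toy round trip through the same steps, in case it helps to see them in isolation. everything here is an assumption for illustration: the "file" is an in-memory bytes object, and this `read_many` is a naive one-read-per-scalar stand-in, not the real fetcher:

```python
import struct
import torch

# stand-in for the file contents: 12 native-endian float32 values 0.0 .. 11.0
buf = struct.pack('=12f', *[float(v) for v in range(12)])
readsize = 4  # bytes per float32

def read_many(buf, offset_lengths):
    # naive stand-in: yields one bytes object per (offset, length) row
    for off, ln in offset_lengths.tolist():
        yield buf[off:off + ln]

offsets = torch.tensor([[0, 5], [10, 3]])  # element indices, any shape
offset_lengths = torch.empty([offsets.numel(), 2], dtype=torch.int64)
offset_lengths[:, 0] = offsets.view(-1)
offset_lengths[:, 0] *= readsize           # element index -> byte offset
offset_lengths[:, 1] = readsize

datas = b''.join(read_many(buf, offset_lengths))
# bytearray gives frombuffer a writable buffer, which avoids a warning
fetched = torch.frombuffer(
    bytearray(datas), dtype=torch.float32, count=offsets.numel()
).view(offsets.shape)
assert fetched.tolist() == [[0.0, 5.0], [10.0, 3.0]]
```

since the values equal their own element indices, the result is easy to eyeball: `fetched` has the same shape and contents as `offsets`.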
which is called by something like this code in F_linear:
major_stride, minor_stride = weight.stride()
offsets = torch.add(*torch.meshgrid(
    torch.arange(weight.shape[0]) * major_stride,
    top_k_indices * minor_stride,
    indexing='ij'
))
... [below simplified removing local unfixed bugs to try to match new code here]
product = torch.matmul(
    input,
    weight.fetch_scalars(offsets, progress=name, validate_usage=False).T
)
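a quick sanity check of the offset grid on a toy in-memory weight (made-up values, and it assumes a contiguous tensor with no storage offset so that `flatten()` matches storage order). the meshgrid sum addresses exactly the elements of `weight[:, top_k_indices]`:

```python
import torch

weight = torch.arange(12.0).reshape(3, 4)   # toy [out_features, in_features]
top_k_indices = torch.tensor([3, 1])

major_stride, minor_stride = weight.stride()  # (4, 1) for this layout
offsets = torch.add(*torch.meshgrid(
    torch.arange(weight.shape[0]) * major_stride,
    top_k_indices * minor_stride,
    indexing='ij',
))

# indexing the flat storage with the grid equals a plain column gather
gathered = weight.flatten()[offsets]
assert offsets.tolist() == [[3, 1], [7, 5], [11, 9]]
assert torch.equal(gathered, weight[:, top_k_indices])
```

so the grid has one row per output feature and one column per selected input index, and each entry is the element's position in storage.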