The next step for me is probably to journal out which conditions would exercise the
untested code path, or what logic error is involved in having it there.
from collections import namedtuple
import bisect

class Chunk:
    """A byte range ``[start, end)`` together with its payload.

    ``height``, ``leaf_count`` and ``age`` are bookkeeping values read by the
    tree code; a plain chunk defaults to one leaf of height 0 and age 0.
    """

    def __init__(self, start, end, data, height=0, leaf_count=1, age=0):
        # Range bounds; __len__ treats the range as end - start.
        self.start, self.end = start, end
        self.data = data
        self.height, self.leaf_count, self.age = height, leaf_count, age

    def __len__(self):
        """Return the width of the range."""
        span = self.end - self.start
        return span

class Flush(Chunk):
    """One generation of a layered, range-indexed byte store.

    ``data`` is a list of ``Flush.Entry`` objects kept sorted by offset and
    non-overlapping (``add`` and ``read`` locate them with ``bisect``).  Each
    entry targets either a ``Chunk`` (a leaf holding bytes) or an older
    ``Flush`` viewed through the entry's range.  ``Flush(prev_flush=p)``
    starts a new generation with ``age = p.age + 1`` whose initial content is
    one entry covering all of ``p``, plus ``p``'s entries that were at least
    ``max_height`` tall, re-added directly.

    Bookkeeping attributes:
      * ``leaf_count`` -- number of leaves reachable through ``data``.
      * ``height``     -- 1 + the tallest entry in ``data``.
      * ``max_height`` -- ``leaf_count.bit_length()``; branch re-joins whose
        result would reach it are refused.
    """

    class Entry(Chunk):
        """A view of ``chunk`` restricted to ``[start, end)``.

        ``path`` lists the objects from the view's root down to ``chunk``
        (the last element is always ``self.data``).  For a ``Flush`` target,
        ``leaf_count`` and ``height`` are recomputed from the target's
        entries that intersect the range.  Entries are never mutated after
        construction, so they can be shared between generations safely.
        """

        def __init__(self, start, end, chunk, path=()):
            super().__init__(start, end, chunk, height=0, leaf_count=1, age=chunk.age)
            # ``path`` is only copied, never mutated; a tuple default avoids
            # the shared-mutable-default pitfall.
            self.path = [*path, self.data]
            if type(chunk) is Flush:
                self.leaf_count = 0
                self.height = 0
                for entry in self.flush_entries():
                    self.leaf_count += entry.leaf_count
                    self.height = max(self.height, entry.height + 1)

        def flush_entries(self):
            """Yield the target flush's entries clipped to this view's range."""
            assert type(self.data) is Flush
            return (
                Flush.Entry(max(entry.start, self.start), min(entry.end, self.end), entry.data, self.path)
                for entry in self.data.data
                if entry.start < self.end and entry.end > self.start
            )

        def chunk_data(self):
            """Return the target chunk's bytes that fall inside this view."""
            assert type(self.data) is Chunk
            return self.data.data[self.start - self.data.start : self.end - self.data.start]

    def __init__(self, prev_flush = None):
        """Start a new generation, optionally layered over ``prev_flush``."""
        if prev_flush is not None:
            leaf_count = prev_flush.leaf_count
            self.max_height = leaf_count.bit_length()
            super().__init__(prev_flush.start, prev_flush.end, [], leaf_count=0, age=prev_flush.age+1)
            self.add(prev_flush)
            # Re-add prev_flush's tall entries directly, so they are not nested
            # one level deeper under the wrapper entry added just above.
            for entry in prev_flush.data:
                if entry.height >= self.max_height:
                    self.add(entry)
        else:
            # Previously left unset for a root flush; 0 == (0).bit_length().
            self.max_height = 0
            super().__init__(None, None, [], height=1, leaf_count=0)

    def add(self, *adjacents):
        """Insert entries into this flush, replacing whatever they overlap.

        ``adjacents`` are assumed sorted by ``start`` and mutually
        non-overlapping (only the first ``start`` and the last ``end`` are used
        to locate the replaced span).  Each item is either a ``Flush.Entry``
        (inserted as-is) or a ``Chunk``/``Flush`` (wrapped in an entry spanning
        its whole range).  Overlapped entries are trimmed, leafless branches
        are dropped, single-child chains are collapsed, and neighbours are
        merged where possible.  Existing entries are never mutated.
        """
        if not adjacents:
            return
        adjacents = [
            adjacent if type(adjacent) is Flush.Entry
            else Flush.Entry(adjacent.start, adjacent.end, adjacent)
            for adjacent in adjacents
        ]
        new_start, new_end = adjacents[0].start, adjacents[-1].end
        if self.start is None:
            self.start, self.end = new_start, new_end
        else:
            self.start = min(self.start, new_start)
            self.end = max(self.end, new_end)

        # first idx with end > new_start
        start_idx = bisect.bisect_right([entry.end for entry in self.data], new_start)
        # first idx with start >= new_end
        end_idx = bisect.bisect_left([entry.start for entry in self.data], new_end, start_idx)
        replaced = self.data[start_idx:end_idx]
        if replaced:
            first, last = replaced[0], replaced[-1]
            if first.start < new_start:
                # Keep the uncovered head of the first overlapped entry, with its
                # ancestry so it can later re-join branches of the same root.
                adjacents.insert(0, Flush.Entry(first.start, new_start, first.data, first.path[:-1]))
                if start_idx > 0:
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    start_idx -= 1
                    replaced.insert(0, self.data[start_idx])
                    adjacents.insert(0, self.data[start_idx])
            if last.end > new_end:
                # Keep the uncovered tail of the last overlapped entry.
                adjacents.append(Flush.Entry(new_end, last.end, last.data, last.path[:-1]))
                if end_idx < len(self.data):
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    replaced.append(self.data[end_idx])
                    adjacents.append(self.data[end_idx])
                    end_idx += 1

        # Drop branches with no leaves left; make the rest shallower.
        for idx in reversed(range(len(adjacents))):
            if adjacents[idx].leaf_count == 0:
                del adjacents[idx]
            else:
                adjacents[idx] = self._collapse_single_child_roots(adjacents[idx])

        # Merge neighbours right to left; a merged entry may merge again with
        # the entry on its left.  Merges build new entries rather than mutating
        # old ones, which may also live in ``replaced`` or in an older flush.
        idx = len(adjacents) - 2
        while idx >= 0:
            merged = self._merge_pair(adjacents[idx], adjacents[idx + 1])
            if merged is not None:
                adjacents[idx:idx + 2] = [merged]
            idx -= 1

        self.leaf_count += sum(entry.leaf_count for entry in adjacents)
        self.leaf_count -= sum(entry.leaf_count for entry in replaced)
        self.max_height = self.leaf_count.bit_length()
        self.data[start_idx:end_idx] = adjacents
        # ``default`` covers a flush whose entries were all pruned away.
        self.height = max((entry.height for entry in self.data), default=0) + 1
        # Debug: self.check_leaf_count(self.start, self.end)

    @staticmethod
    def _collapse_single_child_roots(entry):
        """Descend from ``entry`` while its Flush target has exactly one child in range."""
        while type(entry.data) is Flush:
            children = entry.flush_entries()
            only_child = next(children, None)
            if only_child is None or next(children, None) is not None:
                break
            entry = only_child
        return entry

    def _merge_pair(self, left, right):
        """Return one entry replacing neighbours ``left`` and ``right``, or None."""
        # Touching writes from this generation become a single chunk.
        if (
            left.age == self.age and
            right.age == self.age and
            type(left.data) is Chunk and
            type(right.data) is Chunk and
            left.end == right.start
        ):
            chunk = Chunk(
                left.start,
                right.end,
                left.chunk_data() + right.chunk_data(),
                age=self.age,  # previously defaulted to 0, losing the generation
            )
            return Flush.Entry(left.start, right.end, chunk)

        # Branches descending from a common ancestor are re-joined under it.
        shared = []
        for left_node, right_node in zip(left.path, right.path):
            if left_node is not right_node:
                break  # only the leading run of identical nodes is shared
            shared.append(left_node)
        if not shared:
            return None
        depth = len(shared)
        if (
            left.height + len(left.path) - depth >= self.max_height or
            right.height + len(right.path) - depth >= self.max_height
        ):
            return None
        ancestor, ancestor_path = shared[-1], shared[:-1]
        if left.end != right.start:
            between_entry = Flush.Entry(left.end, right.start, ancestor, ancestor_path)
            if between_entry.leaf_count > 0:
                # the shared root contains leaves in between that have been removed
                return None
        # A fresh entry recomputes leaf_count, height, age and path exactly.
        return Flush.Entry(left.start, right.end, ancestor, ancestor_path)

    def write(self, offset, data):
        """Store ``data`` at ``offset`` as a new leaf of this generation.

        Empty writes are ignored: they change no bytes but would otherwise
        add a zero-width leaf and split the entry they land in.
        """
        if not data:
            return None
        chunk = Chunk(offset, offset + len(data), data, age=self.age)
        return self.add(chunk)

    def read(self, start, max_end = float('inf')):
        """Return a run of bytes starting at ``start``; it may be short.

        The run stops at the end of the entry (or gap) containing ``start``
        and never extends past ``max_end``; callers loop to read more.
        Uncovered ranges read as zero bytes.
        """
        # first idx with end > start
        idx = bisect.bisect_right([entry.end for entry in self.data], start)
        if idx == len(self.data):
            # Past the last entry: zeros, clipped to max_end so a nested read
            # cannot spill past its parent entry into a neighbour's range.
            return bytes(min(max_end, start + 4096) - start)
        entry = self.data[idx]
        if entry.start > start:
            end = min(max_end, entry.start)
            return bytes(end - start)
        end = min(max_end, entry.end)
        if type(entry.data) is Flush:
            return entry.data.read(start, end)
        elif type(entry.data) is Chunk:
            datastart = start - entry.data.start
            dataend = end - entry.data.start
            return entry.data.data[datastart : dataend]

    def check_leaf_count(self, start, end):
        """Debug helper: recount leaves in ``[start, end)`` and assert bookkeeping.

        Recurses into nested flushes, comparing recounts with each entry's
        ``leaf_count``; when the full range is checked it also compares with
        ``self.leaf_count``.  Returns the recounted number of leaves.
        """
        leaf_count = 0
        wrapper = Flush.Entry(start, end, self)
        for entry in wrapper.flush_entries():
            if type(entry.data) is Flush:
                entry_leaf_count = entry.data.check_leaf_count(entry.start, entry.end)
                assert entry_leaf_count == entry.leaf_count
                leaf_count += entry_leaf_count
            else:
                leaf_count += entry.data.leaf_count
        assert leaf_count == wrapper.leaf_count
        if start == self.start and end == self.end:
            assert leaf_count == self.leaf_count
        return leaf_count



if __name__ == '__main__':
    import random

    def compare(store, comparison):
        """Read ``store`` front to back and assert it equals ``comparison``."""
        total = len(comparison)
        offset = 0
        while offset < total:
            piece = store.read(offset)[:total - offset]
            assert piece == comparison[offset:offset + len(piece)]
            offset += len(piece)
        # Optional extra check: store.check_leaf_count(store.start, store.end)
        return True

    def random_write(limit):
        """Pick a random span in ``[0, limit]`` and random bytes filling it."""
        lo, hi = sorted((random.randint(0, limit), random.randint(0, limit)))
        # Pull the end halfway back toward the start to keep writes shorter.
        hi = (hi + lo) // 2
        width = hi - lo
        payload = random.randint(0, (1 << (width * 8)) - 1).to_bytes(width, 'little')
        return lo, hi, payload

    random.seed(0)
    SIZE = 4096
    comparison = bytearray(SIZE)
    store = Flush()
    for flushes in range(1024):
        for _ in range(random.randint(1, 16)):
            lo, hi, payload = random_write(SIZE)
            store.write(lo, payload)
            comparison[lo:hi] = payload
        compare(store, comparison)
        print('OK', len(store.data), 'x', store.height, 'count =', store.leaf_count, 'flushes =', flushes)
        store = Flush(prev_flush=store)
        compare(store, comparison)

Reply via email to