Contains bugfixes regarding missed writes and branches when merging.
I started looking into append-only behavior (writing only to the last
offset), and performance seems much worse than in other implementations
(every root is increasingly wide), maybe because I am rebalancing the tree
by simply attaching branches to the root.
It is a little confusing to me nowadays how, when the tree is balanced on
append-only storage, this means adding internodes to the same storage
location as the root. Forming an equivalence between an ideal tree and the
"update" tree, in which every node newly added as leaves are added goes
into the same "update" root node, could be helpful.
import bisect
class Chunk:
    """A byte range ``[start, end)`` together with the payload stored for it.

    Attributes:
        start: offset of the first byte covered, or ``None`` while unset.
        end: offset one past the last byte covered, or ``None`` while unset.
        data: the payload; elsewhere in this file it is either the written
            bytes (leaf) or a list of entries (``Flush``).
        height: 0 for a leaf, larger for nodes that contain other chunks.
        leaf_count: number of leaves counted under this chunk.
        age: generation number of the flush that produced the chunk.
    """

    def __init__(self, start, end, data, height=0, leaf_count=1, age=0):
        self.start = start
        self.end = end
        self.data = data
        self.height = height
        self.leaf_count = leaf_count
        self.age = age

    def __len__(self):
        """Return the number of bytes spanned, or 0 while the range is unset."""
        # An empty Flush is built with start=end=None; subtracting those
        # would raise TypeError, so report an empty span instead.
        if self.start is None or self.end is None:
            return 0
        return self.end - self.start

    def is_leaf(self):
        """Return True when this chunk holds data directly (height 0)."""
        return self.height == 0
class Flush(Chunk):
    """One generation of an overlay of byte-range writes.

    ``data`` is a list of :class:`Flush.Entry` views ordered by offset;
    ``add`` and ``read`` bisect on the entries' ``start``/``end`` values and
    therefore depend on that ordering.  An entry views either a leaf
    ``Chunk`` written into this flush or (part of) an older ``Flush``.  A
    flush created from ``prev_flush`` is seeded by adding a view of the whole
    previous flush.

    Attributes beyond those of ``Chunk``:
        max_height: depth budget used by ``add``; an entry whose
            ``height + 1`` exceeds it is replaced by views of its children.
            Recomputed as ``leaf_count.bit_length()`` after every ``add``.
    """

    class Entry(Chunk):
        """A view of ``chunk`` clipped to ``[start, end)``.

        ``path`` records the chunks walked through to reach this entry; the
        entry's own chunk is always appended as the last element.
        """

        def __init__(self, start, end, chunk, path=(), height=0, leaf_count=1):
            """Create the view.

            For a non-leaf ``chunk`` the ``height`` and ``leaf_count``
            defaults are placeholders: both are recomputed from the children
            visible inside ``[start, end)``.
            """
            super().__init__(start, end, chunk, height=height, leaf_count=leaf_count, age=chunk.age)
            # Copy so the caller's sequence is never mutated by the append.
            self.path = list(path)
            self.path.append(self.data)
            if not chunk.is_leaf() and (leaf_count == 1 or height == 0):
                assert leaf_count == 1 and height == 0
                self.leaf_count = 0
                self.height = 1
                for entry in self.flush_entries():
                    self.leaf_count += entry.leaf_count
                    self.height = max(self.height, entry.height + 1)

        def flush_entries(self):
            """Return a generator of child views clipped to this entry's range.

            Only children of the viewed (non-leaf) chunk that overlap
            ``[start, end)`` are produced; each carries this entry's path.
            """
            assert not self.data.is_leaf()
            return (
                Flush.Entry(max(entry.start, self.start), min(entry.end, self.end), entry.data, self.path)
                for entry in self.data.data
                if entry.start < self.end and entry.end > self.start
            )

        def chunk_data(self):
            """Return the bytes of the viewed leaf chunk that fall in this view."""
            assert self.is_leaf()
            return self.data.data[self.start - self.data.start : self.end - self.data.start]

    def __init__(self, prev_flush=None):
        """Create an empty flush, or one layered on top of ``prev_flush``."""
        if prev_flush is not None:
            super().__init__(prev_flush.start, prev_flush.end, [], height=1, leaf_count=0, age=prev_flush.age + 1)
            # An empty previous flush has leaf_count 0, whose bit_length is 0;
            # a budget of 0 would make add() try to expand leaf writes, so
            # keep the same floor of 1 that a fresh flush uses.
            self.max_height = max(1, prev_flush.leaf_count.bit_length())
            prev_entry = Flush.Entry(self.start, self.end, prev_flush)
            self.add(prev_entry)
        else:
            super().__init__(None, None, [], height=1, leaf_count=0)
            self.max_height = 1

    def add(self, *adjacents):
        """Insert entries, given in offset order, replacing what they overlap.

        Steps: grow ``start``/``end``; expand entries deeper than
        ``max_height`` into their children; trim the existing entries cut by
        the new range; drop leafless views and splice out single-child
        roots; merge touching same-age writes and neighbours that share an
        ancestor; finally splice the result into ``data`` and recompute
        ``leaf_count``, ``height`` and ``max_height``.
        """
        adjacents = list(adjacents)
        if not adjacents:
            # Nothing to insert; the indexing below needs at least one entry.
            return
        if self.start is None:
            self.start = adjacents[0].start
            self.end = adjacents[-1].end
        else:
            self.start = min(self.start, adjacents[0].start)
            self.end = max(self.end, adjacents[-1].end)

        # expand adjacents that are too deep [should go after start_idx and end_idx are adjust, to find correct max_height easier]
        idx = 0
        while idx < len(adjacents):
            entry = adjacents[idx]
            if entry.height + 1 > self.max_height:
                subadjacents = []
                # [shallow_start, shallow_end) accumulates the run of children
                # that are shallow enough to stay grouped under entry.data.
                shallow_start = entry.start
                shallow_end = shallow_start
                for subentry in entry.flush_entries():
                    if subentry.height + 2 > self.max_height:
                        if shallow_end != shallow_start:
                            # NOTE(review): built without `path`, so the new
                            # view's path restarts at entry.data.
                            subadjacents.append(Flush.Entry(shallow_start, shallow_end, entry.data))
                        subadjacents.append(subentry)
                        shallow_start = subentry.end
                    shallow_end = subentry.end
                if shallow_end != shallow_start:
                    subadjacents.append(Flush.Entry(shallow_start, shallow_end, entry.data))
                # idx is not advanced: the replacements are examined again.
                adjacents[idx:idx+1] = subadjacents
            else:
                idx += 1
        if not adjacents:
            # Expanding a view with no children (e.g. an empty previous
            # flush) leaves nothing to insert.
            return

        # first idx with end >= start
        start_idx = bisect.bisect_left([entry.end for entry in self.data], adjacents[0].start)
        # first idx with start > end
        end_idx = bisect.bisect_right([entry.start for entry in self.data], adjacents[-1].end, start_idx)
        replaced = self.data[start_idx:end_idx]
        if replaced:
            if replaced[0].start < adjacents[0].start:
                # Keep the part of the first replaced entry left of the new range.
                adjacents.insert(
                    0,
                    Flush.Entry(
                        replaced[0].start, adjacents[0].start, replaced[0].data
                    )
                )
                if start_idx > 0 and replaced[0].end > adjacents[1].start:
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    start_idx -= 1
                    replaced.insert(0, self.data[start_idx])
                    adjacents.insert(0, self.data[start_idx])
            if replaced[-1].end > adjacents[-1].end:
                # Keep the part of the last replaced entry right of the new range.
                adjacents.append(
                    Flush.Entry(
                        adjacents[-1].end, replaced[-1].end, replaced[-1].data
                    )
                )
                if end_idx < len(self.data) and replaced[-1].start < adjacents[-2].end:
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    replaced.append(self.data[end_idx])
                    adjacents.append(self.data[end_idx])
                    end_idx += 1

        # Walk backwards so popping an element does not shift pending indices.
        for idx, entry in reversed(list(enumerate(adjacents))):
            if entry.leaf_count == 0:
                # no leaves left in this branch, remove
                adjacents.pop(idx)
                continue
            count = 0
            subentry = entry
            while count <= 1 and not subentry.data.is_leaf():
                # make branches shallower by splicing out roots with only one child
                parent_entry = subentry
                count = 0
                for subentry in parent_entry.flush_entries():
                    count += 1
                    if count > 1:
                        # More than one child: keep the parent itself.
                        subentry = parent_entry
                        break
            if subentry is not entry:
                # some internodes were removed
                adjacents[idx] = subentry

        idx = len(adjacents) - 1
        while idx > 0:
            idx -= 1
            left_adjacent = adjacents[idx]
            right_adjacent = adjacents[idx+1]
            # merge writes
            if (
                left_adjacent.age == self.age and
                right_adjacent.age == self.age and
                left_adjacent.end == right_adjacent.start
            ):
                left_adjacent.data = Chunk(
                    left_adjacent.start,
                    right_adjacent.end,
                    left_adjacent.chunk_data() + right_adjacent.chunk_data(),
                    age = self.age
                )
                left_adjacent.end = right_adjacent.end
                adjacents.pop(idx+1)
                continue
            # merge branches with shared parents
            shared_path = [
                left_parent for left_parent, right_parent
                in zip(left_adjacent.path, right_adjacent.path)
                if left_parent is right_parent
            ]
            if shared_path and left_adjacent.height + len(left_adjacent.path) - len(shared_path) < self.max_height and right_adjacent.height + len(right_adjacent.path) - len(shared_path) < self.max_height:
                if left_adjacent.end != right_adjacent.start:
                    assert left_adjacent.end < right_adjacent.start
                    between_entry = Flush.Entry(
                        left_adjacent.end,
                        right_adjacent.start,
                        shared_path[-1]
                    )
                    if between_entry.leaf_count > 0:
                        # the shared root contains leaves in between that have been removed
                        continue
                print(f'Merging {len(left_adjacent.path)}:{left_adjacent.height}, {len(right_adjacent.path)}:{right_adjacent.height} -> {len(shared_path)}:{left_adjacent.height + len(left_adjacent.path) - len(shared_path)}')
                # NOTE(review): Entry.__init__ appends the chunk to `path`
                # again, so merged.path ends with shared_path[-1] twice;
                # confirm whether path=shared_path[:-1] was intended.
                merged = Flush.Entry(
                    left_adjacent.start,
                    right_adjacent.end,
                    chunk = shared_path[-1],
                    path = shared_path,
                    # letting Entry recalculate leaf_count and height is a quick way to handle overlap
                )
                adjacents[idx:idx+2] = [merged]

        self.data[start_idx:end_idx] = adjacents
        # using Entry to recalculate leaf_count is a quick-to-implement way to handle not double-counting chunks that span trimmed groups
        proxy_entry = Flush.Entry(self.data[0].start, self.data[-1].end, self)
        self.leaf_count = proxy_entry.leaf_count
        self.height = proxy_entry.height
        assert self.leaf_count > 0
        # height <= max_height is not guaranteed here: the checks above used
        # the max_height computed by the previous add.
        self.max_height = self.leaf_count.bit_length()

    def write(self, offset, data):
        """Store ``data`` at ``offset`` as a new leaf chunk of this flush."""
        chunk = Chunk(offset, offset + len(data), data, age=self.age)
        entry = Flush.Entry(offset, offset + len(data), chunk)
        return self.add(entry)

    def read(self, start, max_end=float('inf')):
        """Return one run of bytes beginning at ``start``.

        The run stops at the next entry boundary or at ``max_end``, whichever
        comes first, so callers loop to read a longer range.  Offsets with no
        entry read as zero bytes; past the last entry at most 4096 zero bytes
        are returned per call.
        """
        # first idx with end > start
        idx = bisect.bisect_right([entry.end for entry in self.data], start)
        if idx == len(self.data):
            # Nothing stored at or after start.  Honour max_end instead of
            # always handing back 4096 bytes (the value used when unbounded).
            return bytes(max(0, min(4096, max_end - start)))
        entry = self.data[idx]
        if entry.start > start:
            # Hole before the next entry.
            end = min(max_end, entry.start)
            return bytes(end - start)
        end = min(max_end, entry.end)
        if entry.data.is_leaf():
            datastart = start - entry.data.start
            dataend = end - entry.data.start
            return entry.data.data[datastart : dataend]
        else:
            return entry.data.read(start, end)

    def check_leaf_count(self, start, end):
        """Debug helper: recount leaves in ``[start, end)`` and assert the
        cached ``leaf_count``/``height`` values agree.  Returns the count.
        """
        leaf_count = 0
        height = 1
        wrapper = Flush.Entry(start, end, self)
        for entry in wrapper.flush_entries():
            if isinstance(entry.data, Flush):
                entry_leaf_count = entry.data.check_leaf_count(entry.start, entry.end)
                assert entry_leaf_count == entry.leaf_count
                leaf_count += entry_leaf_count
            else:
                leaf_count += entry.data.leaf_count
            height = max(height, entry.height + 1)
        assert leaf_count == wrapper.leaf_count
        assert height == wrapper.height
        if start == self.start and end == self.end:
            assert leaf_count == self.leaf_count
            assert height == self.height # oops not met yet
        return leaf_count
def main():
    """Randomized self-check of Flush against a plain bytearray.

    With a fixed seed, performs 1024 rounds of 1-16 random writes into a
    4096-byte space, mirroring every write into a bytearray.  After each
    round the store is read back and compared, a status line is printed, a
    new Flush is layered on top of the old one, and the comparison is
    repeated against the new flush.
    """
    import random

    random.seed(0)
    capacity = 4096
    expected = bytearray(capacity)
    store = Flush()

    def matches(flush, reference):
        # Read the whole space back run by run; each read may stop early at
        # an entry boundary, so advance by however much came back.
        position = 0
        total = len(reference)
        while position < total:
            piece = flush.read(position)[:total - position]
            assert piece == reference[position:position + len(piece)]
            position += len(piece)
        return True

    for flush_number in range(1024):
        write_count = random.randint(1, 16)
        for _ in range(write_count):
            offset = random.randint(0, capacity - 1)
            # Clip the write so it never runs past the end of the space.
            length = min(capacity - offset, random.randint(1, 1024))
            payload = random.getrandbits(length * 8).to_bytes(length, 'little')
            store.write(offset, payload)
            expected[offset:offset + length] = payload
        matches(store, expected)
        print('OK', len(store.data), 'x', store.height, '/', store.max_height, 'count =', store.leaf_count, 'flushes =', flush_number)
        store = Flush(prev_flush=store)
        matches(store, expected)
# Run the randomized self-check only when executed as a script, not on import.
if __name__ == '__main__':
    main()
#import cProfile
#cProfile.run('main()')