Here is the failed implementation I tried to make today, and the working implementation I made some time ago. I haven't looked at them side-by-side.

simple_beam_search is the one I tried to write today.

try_multigen is the one from some time ago; at the time I was trying out parts for synthetic text adventure games.

I believe both of them use a batch size of 1. I don't consider it complex to first stabilize the search at a batch size of 1, and then batch the top batch_size items of the priority queue together, as sketched below.
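Roughly what I mean by batching the queue (an untested sketch; step_batched and expand_batch are made-up names standing in for the per-item logic of the implementations below):

import bisect

def step_batched(queue, batch_size, expand_batch):
    # pop the most probable entries; the queue is kept sorted ascending
    top = [queue.pop() for _ in range(min(batch_size, len(queue)))]
    # expand_batch would run one model forward pass over the whole chunk and
    # return new queue entries keyed the same way as the single-item version
    for entry in expand_batch(top):
        bisect.insort_right(queue, entry)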
import torch
import bisect

def simple_beam_search(model, prompt_ids, num_beams, tokenizer=None):
    '''A simple beam search intended to select diverse, high-probability results.'''
    # "diverse" is not guaranteed; the queue just never revisits the same continuation.
    # batch size is not hard to implement: process the queue in chunks.
    prompt_ids = torch.tensor(prompt_ids, device=model.device)
    prompt_len = len(prompt_ids)
    # maintain a queue of candidates, kept sorted ascending by probability.
    # this initial generation could probably be simplified away by reordering the loop,
    # maybe with a tiny inner loop.
    logits = model.forward(prompt_ids[None,:]).logits[:,-1][0]
    token_ids = torch.argsort(logits) # ascending, so the last entry is the argmax
    token_probs = torch.softmax(logits, dim=0, dtype=float)
    sort_uid = 0 # tie-breaker so equal-probability entries still compare
    queue = [(torch.tensor(1.0), torch.tensor(1.0), sort_uid, prompt_ids[prompt_len:], token_ids, token_probs)]; sort_uid += 1
    result = []
    while len(result) < num_beams:
        sort_prob, base_prob, _, input_ids, token_ids, token_probs = queue.pop()
        if len(input_ids) and input_ids[-1] == model.config.eos_token_id:
            if len(input_ids) > 1:
                if tokenizer:
                    print('=>', round(base_prob.item()*10000)/100, '%', tokenizer.decode(input_ids))
                result.append(input_ids[:-1])
            continue
        # two things to add back to the queue: this same entry re-keyed at its
        # next-best token's probability, and its best continuation.
        sort_prob = base_prob * token_probs[token_ids[-2]]
        bisect.insort_right(queue, (sort_prob, base_prob, sort_uid, input_ids, token_ids[:-1], token_probs)); sort_uid += 1
        next_token_id = token_ids[-1]
        next_base_prob = base_prob * token_probs[token_ids[-1]]
        # generate another token!
        input_ids = torch.cat((input_ids, next_token_id[None]))
        if tokenizer:
            print(round(next_base_prob.item()*10000)/100, '%', tokenizer.decode(input_ids))
        logits = model.forward(torch.cat((prompt_ids, input_ids))[None,:]).logits[:,-1][0]
        token_ids = torch.argsort(logits)
        token_probs = torch.softmax(logits, dim=0) #, dtype=float)
        bisect.insort_right(queue, (next_base_prob, next_base_prob, sort_uid, input_ids, token_ids, token_probs)); sort_uid += 1
        queue = queue[-128:] # keep only the 128 most probable entries; the queue is sorted ascending
    return result
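For reference, roughly how I'd call it (an untested sketch; it assumes any Hugging Face causal LM, here the same bloomz-560m checkpoint the second script loads):

import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained('bigscience/bloomz-560m')
model = transformers.AutoModelForCausalLM.from_pretrained('bigscience/bloomz-560m')
prompt_ids = tokenizer.encode('Things you might find in a magic closet:')
with torch.no_grad():
    beams = simple_beam_search(model, prompt_ids, num_beams=4, tokenizer=tokenizer)
for beam in beams:
    print(tokenizer.decode(beam)) # each beam holds only the generated continuation ids

Below is try_multigen, the older working implementation.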
import transformers
import torch
DEFAULT_PIPELINE = transformers.pipeline('text-generation', 'bigscience/bloomz-560m', dtype=torch.bfloat16) # bloomz: 560m, 1.1b, 1.7b, 3b, 7.1b, 176b ; mt0: small (300m), base, large, xl, xxl
# a simplified interface would go here

# Generated associates text + key/value caches together, and can continue the text
class Generated:
    def __init__(self, text, token_ids=None, outputs=None, batch_idx=None, pipeline=None, model=None, tokenizer=None, **kwparams):
        if pipeline is None:
            pipeline = DEFAULT_PIPELINE
        if model is None:
            model = pipeline.model
        if tokenizer is None:
            tokenizer = pipeline.tokenizer
        self.model = model
        self.tokenizer = tokenizer
        self.kwparams = kwparams
        if batch_idx is not None and type(text) in (list, tuple):
            text = text[batch_idx]
        self._text = text
        if token_ids is None:
            token_ids = self.tokenizer(text, return_tensors='pt').input_ids # , attention_mask
            assert token_ids.shape[0] == 1
            token_ids = token_ids[0]
        else:
            if batch_idx is None:
                assert len(token_ids.shape) == 1
            elif type(token_ids) in (list, tuple):
                token_ids = token_ids[batch_idx]
                assert len(token_ids.shape) == 1
            elif len(token_ids.shape) > 1:
                assert len(token_ids.shape) == 2
                token_ids = token_ids[batch_idx]
        self._token_ids = token_ids
        if outputs is None:
            assert batch_idx is None
            self.model.eval()
            outputs = self.model(self._token_ids[None], use_cache=True, **self.kwparams) #, return_dict=True, output_hidden_states=True)
        assert len(outputs.logits.shape) == 3
        if batch_idx is None:
            assert outputs.logits.shape[0] == 1
            self._logits = outputs.logits[0,-1].detach().clone()
            self._past_key_values = outputs.past_key_values
        else:
            self._logits = outputs.logits[batch_idx,-1].detach().clone()
            # slice this batch entry's row out of each layer's cached keys and values
            self._past_key_values = tuple([
                (
                    layer_past[0][batch_idx:batch_idx+1],
                    layer_past[1][batch_idx:batch_idx+1]
                )
                for layer_past in outputs.past_key_values
            ])
    def __str__(self):
        return self._text
    def next_id_probs(self):
        # yield (token_id, probability) pairs in descending probability order
        probs, ids = self._logits.softmax(dim=-1).detach().to(torch.bfloat16).sort(descending=True)
        for idx in range(len(probs)):
            yield ids[idx], probs[idx]
    def next_str_probs(self):
        # yield (suffix_text, probability) pairs for each candidate next token
        offset = len(str(self))
        for id, prob in self.next_id_probs():
            yield self.tokenizer.decode(torch.cat((self._token_ids, id[...,None]), dim=-1))[offset:], prob
    def next_obj_probs(self, **kwparams):
        # yield (Generated, probability) pairs, running the model once per candidate
        for id, prob in self.next_id_probs():
            suffix_ids = id[...,None]
            token_ids = torch.cat((self._token_ids, suffix_ids), dim=-1)
            suffix = self.tokenizer.decode(token_ids)[len(str(self)):]
            obj = self(suffix, suffix_ids, **kwparams)
            yield obj, prob
    def __call__(self, suffix, suffix_ids=None, **kwparams):
        kwparams = {**self.kwparams, **kwparams}
        # for now; could return self or pass str(self) with other kwparams if both are empty
        assert bool(suffix) or suffix_ids is not None
        if suffix_ids is not None:
            token_ids = torch.cat((self._token_ids, suffix_ids))
        if suffix:
            text = str(self) + suffix
            if suffix_ids is not None:
                #assert self.tokenizer.encode(text) == token_ids # disabled: the same text can have multiple encodings
                #assert self.tokenizer.decode(token_ids) == text # disabled: token ids can represent different text concatenated than separate
                decoded_text = self.tokenizer.decode(token_ids)
                assert decoded_text.replace("'",'').replace(' ','') == text.replace("'",'').replace(' ','')
                text = decoded_text
            else:
                token_ids = self.tokenizer(text, return_tensors='pt').input_ids
                assert token_ids.shape[0] == 1
                token_ids = token_ids[0]
                suffix_ids = token_ids[len(self._token_ids):]
        else:
            text = self.tokenizer.decode(token_ids)
        self.model.eval()
        # reuse the cached keys/values so only the suffix tokens are processed
        outputs = self.model(suffix_ids[None], past_key_values=self._past_key_values, use_cache=True, **kwparams)
        #outputs = self.model(token_ids[None], use_cache=False, **kwparams)
        obj = Generated(text=text, token_ids=token_ids, outputs=outputs, pipeline=None, model=self.model, tokenizer=self.tokenizer, **kwparams)
        return obj
    @staticmethod
    def parallel(objects, suffices, suffix_ids=None, suffix_attn_masks=None, **kwparams):
        batchsize = len(objects)
        obj_kwparams = {}
        for obj in objects:
            obj_kwparams.update(obj.kwparams)
            assert obj.model is objects[0].model
        kwparams = {**obj_kwparams, **kwparams}
        assert bool(suffices) or suffix_ids is not None # not implemented at this time
        if suffix_attn_masks is not None:
            assert suffix_ids is not None
        if suffix_ids is not None:
            if suffix_attn_masks is not None:
                # quick thing, have not tested
                token_ids = [torch.cat((objects[idx]._token_ids, suffix_ids[idx][suffix_attn_masks[idx] != 0])) for idx in range(batchsize)]
            else:
                token_ids = [torch.cat((objects[idx]._token_ids, suffix_ids[idx])) for idx in range(batchsize)]
        if suffices:
            texts = [
                str(objects[idx]) + suffices[idx]
                for idx in range(batchsize)
            ]
            if suffix_ids is not None:
                for idx in range(batchsize):
                    assert objects[idx].tokenizer.decode(token_ids[idx]) == texts[idx]
            else:
                # all objects are assumed to share a tokenizer
                token_ids = objects[0].tokenizer(texts, padding=False, return_tensors='pt').input_ids
                assert token_ids.shape[0] == batchsize
                # right-align each suffix in a left-padded batch, masking out the padding
                suffix_ids_maxlen = max([len(token_ids[idx]) - len(objects[idx]._token_ids) for idx in range(batchsize)])
                suffix_ids = torch.empty((batchsize, suffix_ids_maxlen), dtype=int)
                suffix_attn_masks = torch.ones_like(suffix_ids)
                for idx in range(batchsize):
                    suffix_ids_length = len(token_ids[idx]) - len(objects[idx]._token_ids)
                    suffix_ids[idx][-suffix_ids_length:] = token_ids[idx][-suffix_ids_length:]
                    suffix_attn_masks[idx][:-suffix_ids_length] = 0
        else: # suffix_ids only
            texts = [
                objects[idx].tokenizer.decode(token_ids[idx])
                for idx in range(batchsize)
            ]
        # concatenate each layer's cached keys and values along the batch dimension
        past_key_values = tuple([
            (
                torch.cat(
                    [
                        objects[idx]._past_key_values[layer][0]
                        for idx in range(batchsize)
                    ],
                    dim=0),
                torch.cat(
                    [
                        objects[idx]._past_key_values[layer][1]
                        for idx in range(batchsize)
                    ],
                    dim=0)
            )
            for layer in range(len(objects[0]._past_key_values))
        ])
        objects[0].model.eval()
        outputs = objects[0].model(input_ids=suffix_ids, past_key_values=past_key_values, attention_mask=suffix_attn_masks, use_cache=True, **kwparams)
        return [
            Generated(text=texts, token_ids=token_ids, outputs=outputs, batch_idx=idx, model=objects[idx].model, tokenizer=objects[idx].tokenizer, **kwparams)
            for idx in range(batchsize)
        ]
    @staticmethod
    def next_objs_probs(objs, **kwparams):
        # advance a batch of objects in lock step, one candidate rank at a time
        batchsize = len(objs)
        id_probs_its = [iter(obj.next_id_probs()) for obj in objs]
        for item in range(len(objs[0]._logits)):
            # the first item in each list may be the one to parallelize
            next_id_probs = [next(it) for it in id_probs_its]
            suffix_ids = torch.stack([id[...,None] for id, prob in next_id_probs])
            token_ids = [torch.cat((objs[idx]._token_ids, suffix_ids[idx]), dim=-1) for idx in range(batchsize)]
            suffices = [objs[idx].tokenizer.decode(token_ids[idx])[len(str(objs[idx])):] for idx in range(batchsize)]
            next_objs = Generated.parallel(objs, suffices, suffix_ids, **kwparams) # suffix_attn_masks
            yield [(next_objs[idx], next_id_probs[idx][1]) for idx in range(batchsize)] # (obj, prob) pairs
class Multiple(Generated):
    def __init__(self, text, eos_text, *extra_eos_texts, params=[], **kwparams):
        super().__init__(text, *params, **kwparams)
        self.eos_texts = [eos_text, *extra_eos_texts]
    def __iter__(self):
        #queue = [(1, 1, None, self, iter(Generated.next_objs_probs([self])))]
        queue = [(1, 1, None, self, iter(self.next_obj_probs()))]
        #prob_scale = 1
        # thinking of control flow if the whole queue were done in parallel:
        # at the moment, it compares each text, and then continues the ones that are not complete.
        # for continuing, it could continue every text, recreate the queue with a single sort pass
        # over the doubled content, truncate it, and reverse it.
        # when doing that, the pass to continue every text would bump into completed texts.
        while len(queue):
            max_prob, base_prob, prev_base, base, it = queue.pop()
            base_text = str(base)
            eos_found = False
            for eos_text in self.eos_texts:
                if base_text[len(str(self)):].endswith(eos_text):
                    generated = base_text[len(str(self)):-len(eos_text)].strip()
                    if len(generated):
                        yield generated, base_prob
                    #else:
                    #    prob_scale *= ... would need to track cumulative prob to do accurately
                    eos_found = True
                    break
            if eos_found:
                continue
            if max_prob == base_prob:
                print(len(queue), round(float(base_prob*100),4), str(base).strip()[:80], end='\r', flush=True)
            #next_base, next_prob = next(it)[0]
            next_base, next_prob = next(it)
            next_prob = next_prob * base_prob
            new_items = [
                #(next_prob, next_prob, base, next_base, iter(Generated.next_objs_probs([next_base]))),
                (next_prob, next_prob, base, next_base, iter(next_base.next_obj_probs())),
                (next_prob, base_prob, prev_base, base, it) # try this one again when its probability drops
            ]
            for item in new_items:
                if len(queue) < 2048 and item[0] > 0.0001:
                    queue.append(item)
                elif len(queue) and queue[0][0] < item[0]:
                    # replace the least probable entry when the queue is full
                    queue[0] = item
            queue.sort(key=lambda tup: tup[:2])
if __name__ == '__main__':
    print("This example is to the point where it is outpacing damaged parts of Karl's cognition that he needs to heal.")
    print("Is there any possibility to add training wheels to it, so the user learns a little relevance?")
    # thinking the user could learn to provide 2 missing concepts when prompted with 1 example
    # ideally to keep growing so it is user-created content rather than ai-created
    try:
        with torch.no_grad():
            #multiple = Multiple('Places you might find in a city:', '</s>', ',')#, '.')
            multiple = Multiple('Things you might find in a magic closet:', '</s>', ',')#, '.')
            #multiple = Multiple('Some striking environments for a fantasy, sci-fi, or historical scene:', '</s>', ',')#, '.')
            #multiple = Multiple('Names for alien home planets:', '</s>', ',') # these ones don't work and would need examples and/or filtering
            #elems = 'Dizquakin, The Wandering Rift, Vrazz, Mortudin, New Star Haven, Fradinabble, The Galactic Bridge, Wuspercin, The Twince Nebula, Tozquarn, Charred Call, Cialla, The Phoenix Moon, Birnat, The Gaping Maw, Cruftspace, Musprit'.split(', ') # not sure about crafting all these examples atm
            #import random
            #random.shuffle(elems)
            #multiple = Multiple(' dizquakin, the wandering rift, vrazz, mortudin, new star haven, fradinabble, the galactic bridge, wuspercin, the twince nebula, tozquarn, charred call, cialla, birnat, the gaping maw, cruftspace, musprit,', '</s>', ',')
            #multiple = Multiple(' ' + ', '.join(elems) + ',', '</s>', ',', '.')
            remaining = 1
            for completion, prob in iter(multiple):
                remaining -= float(prob)
                print(f'{round(float(remaining)*100,4)} + {round(float(prob)*100,4)}: {completion} {" " * len(str(multiple)[:80])}')
    finally:
        print("This example is to the point where it is outpacing damaged parts of Karl's cognition that he needs to heal.")
        print("Is there any possibility to add training wheels to it, so the user learns a little relevance, and must learn to be able to do what the script does?")