here is the failed implementation i tried to make today, and the
working implementation i had made some time ago. i haven't looked at
them side-by-side.

simple_beam_search is the one i tried to write today.
try_multigen is the one from some time ago; at the time, my state of mind was
trying out parts for synthetic text adventure games.

i believe both of them use a batch size of 1. i don't think it would be
complex to first stabilize the search with a batch size of 1, and then
batch the top batch_size items in the priority queue, roughly as in the
sketch below.
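
a rough sketch of that batching idea (score_batch here is a made-up placeholder for a single batched forward pass, not something in either script):

import bisect

# hypothetical sketch of batching the priority queue, not part of either script:
# pop the top batch_size entries, score them together with one batched forward
# pass (score_batch is a placeholder), and re-insert the expansions.
def step_batched(queue, batch_size, score_batch):
    chunk = [queue.pop() for _ in range(min(batch_size, len(queue)))]
    for expansion in score_batch(chunk):  # one batched model call over the whole chunk
        bisect.insort_right(queue, expansion)
    del queue[:-128]  # queue is sorted ascending, so this keeps the best 128 entries
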
import torch
#import tqdm
import bisect
def simple_beam_search(model, prompt_ids, num_beams, tokenizer=None): # could also take verbose and batch_size; batch size is not hard to implement, just process the queue in chunks
    '''A simple beam search implementation intended to select diverse, high-probability results.'''
    # we don't know for sure that it will actually be diverse.

    prompt_ids = torch.tensor(prompt_ids, device=model.device)
    prompt_len = len(prompt_ids)

    #import pdb; pdb.set_trace()
    
    # now we sustain a queue
    # this initial generation could probably be simplified away by reordering the loop, maybe with a tiny inner loop
    logits = model.forward(prompt_ids[None,:]).logits[:,-1][0]
    token_ids = torch.argsort(logits)
    token_probs = torch.softmax(logits, dim=0, dtype=float)
    sort_uid = 0
    queue = [(torch.tensor(1.0), torch.tensor(1.0), sort_uid, prompt_ids[prompt_len:], token_ids, token_probs)]; sort_uid += 1
    result = []
    # might need to reorder some parts
    while len(result) < num_beams:
        sort_prob, base_prob, _, input_ids, token_ids, token_probs = queue.pop()

        if len(input_ids) and input_ids[-1] == model.config.eos_token_id:
            if len(input_ids) > 1:
                if tokenizer:
                    print('=>', round(base_prob.item()*10000)/100, '%', tokenizer.decode(input_ids)) # base_prob: next_base_prob is not defined at this point
                result.append(input_ids[:-1])
            continue
        
        # now we have two things to add back to the queue: the same thing with lower prob, and the next thing
        sort_prob = base_prob * token_probs[token_ids[-2]]
        bisect.insort_right(queue, (sort_prob, base_prob, sort_uid, input_ids, token_ids[:-1], token_probs)); sort_uid += 1
        next_token_id = token_ids[-1]
        next_base_prob = base_prob * token_probs[token_ids[-1]]

        # generate another token!
        input_ids = torch.cat((input_ids, next_token_id[None]))
        if tokenizer:
            print(round(next_base_prob.item()*10000)/100, '%', tokenizer.decode(input_ids))
        logits = model.forward(torch.cat((prompt_ids, input_ids))[None,:]).logits[:,-1][0]
        token_ids = torch.argsort(logits)
        token_probs = torch.softmax(logits, dim=0)#, dtype=float)
        bisect.insort_right(queue, (next_base_prob, next_base_prob, sort_uid, input_ids, token_ids, token_probs)); sort_uid += 1
        queue = queue[-128:] # the queue is sorted ascending, so keep the 128 highest-probability entries rather than the lowest
    return result
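
for reference, this is roughly how i'd expect to call it (a sketch assuming a small causal LM from transformers; the model name and prompt are only examples):

# rough sketch of invoking simple_beam_search; not part of the script above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bigscience/bloomz-560m')
model = AutoModelForCausalLM.from_pretrained('bigscience/bloomz-560m')

prompt_ids = tokenizer('Things you might find in a magic closet:').input_ids
with torch.no_grad():
    beams = simple_beam_search(model, prompt_ids, num_beams=4, tokenizer=tokenizer)
for beam in beams:
    print(tokenizer.decode(beam))
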
import transformers
import torch


DEFAULT_PIPELINE = transformers.pipeline('text-generation', 'bigscience/bloomz-560m', dtype=torch.bfloat16) # bloomz: 560m, 1.1b, 1.7b, 3b, 7.1b, 176b ; mt0: small (300m), base, large, xl, xxl

# a simplified version would go here
# it associates text + keyvals (cached past_key_values) together, and can continue with text

class Generated:
    def __init__(self, text, token_ids = None, outputs = None, batch_idx = None, pipeline = None, model = None, tokenizer = None, **kwparams):
        if pipeline is None:
            pipeline = DEFAULT_PIPELINE
        if model is None:
            model = pipeline.model
        if tokenizer is None:
            tokenizer = pipeline.tokenizer
        self.model = model
        self.tokenizer = tokenizer
        self.kwparams = kwparams
        if batch_idx is not None and type(text) in (list, tuple):
            text = text[batch_idx]
        self._text = text
        if token_ids is None:
            token_ids = self.tokenizer(text, return_tensors='pt').input_ids # , attention_mask
            assert token_ids.shape[0] == 1
            token_ids = token_ids[0]
        else:
            if batch_idx is None:
                assert len(token_ids.shape) == 1
            elif type(token_ids) in (list, tuple):
                token_ids = token_ids[batch_idx]
                assert len(token_ids.shape) == 1
            elif len(token_ids.shape) > 1:
                assert len(token_ids.shape) == 2
                token_ids = token_ids[batch_idx]
        self._token_ids = token_ids
        if outputs is None:
            assert batch_idx is None
            self.model.eval()
            outputs = self.model(self._token_ids[None], use_cache=True, **self.kwparams)#, return_dict=True, output_hidden_states=True)
        assert len(outputs.logits.shape) == 3
        if batch_idx is None:
            assert outputs.logits.shape[0] == 1
            #self._outputs = outputs
            self._logits = outputs.logits[0,-1].detach().clone()
            self._past_key_values = outputs.past_key_values
        else:
            self._logits = outputs.logits[batch_idx,-1].detach().clone()
            self._past_key_values = tuple([
                (
                    past_key_values[0][batch_idx:batch_idx+1],
                    past_key_values[1][batch_idx:batch_idx+1]
                )
                for past_key_values in outputs.past_key_values
            ])
    def __str__(self):
        return self._text
    def next_id_probs(self):
        #logits = self._outputs.logits[0,-1]
        probs, ids = self._logits.softmax(dim=-1).detach().to(torch.bfloat16).sort(descending=True)
        for idx in range(len(probs)):
            yield ids[idx], probs[idx]
    def next_str_probs(self):
        offset = len(str(self))
        for id, prob in self.next_id_probs():
            yield self.tokenizer.decode(torch.cat((self._token_ids, id[...,None]), dim=-1))[offset:], prob
    def next_obj_probs(self, **kwparams):
        for id, prob in self.next_id_probs():
            suffix_ids = id[...,None]
            token_ids = torch.cat((self._token_ids, suffix_ids), dim=-1)
            suffix = self.tokenizer.decode(token_ids)[len(str(self)):]
            obj = self(suffix, suffix_ids, **kwparams)
            yield obj, prob

    def __call__(self, suffix, suffix_ids = None, **kwparams):
        kwparams = {**self.kwparams, **kwparams}
        assert bool(suffix) or bool(suffix_ids) # for now, could return self or pass str(self) with other kwparams if both are falsey
        if suffix_ids:
            token_ids = torch.cat((self._token_ids, suffix_ids))
        if suffix:
            text = str(self) + suffix
            if suffix_ids:
                #assert self.tokenizer.encode(text) == token_ids # disabled: the same text can have multiple encodings
                #assert self.tokenizer.decode(token_ids) == text # disabled: token ids can represent different text concatenated than separate
                decoded_text = self.tokenizer.decode(token_ids)
                assert decoded_text.replace("'",'').replace(' ','') == text.replace("'",'').replace(' ','')
                text = decoded_text
            else:
                token_ids = self.tokenizer(text, return_tensors='pt').input_ids
                assert token_ids.shape[0] == 1
                token_ids = token_ids[0]
                suffix_ids = token_ids[len(self._token_ids):]
        else:
            text = self.tokenizer.decode(token_ids)
        self.model.eval()
        outputs = self.model(suffix_ids[None], past_key_values=self._past_key_values, use_cache=True, **kwparams)
        #outputs = self.model(token_ids[None], use_cache=False, **kwparams)
        obj = Generated(text = text, token_ids = token_ids, outputs = outputs, pipeline = None, model = self.model, tokenizer = self.tokenizer, **kwparams)
        return obj

    @staticmethod
    def parallel(objects, suffices, suffix_ids = None, suffix_attn_masks = None, **kwparams):
        batchsize = len(objects)
        obj_kwparams = {}
        for obj in objects:
            obj_kwparams.update(obj.kwparams)
            assert obj.model is objects[0].model
        kwparams = {**obj_kwparams, **kwparams}
        assert bool(suffices) or bool(suffix_ids) # notimplementedatthistime
        if suffix_attn_masks:
            assert suffix_ids
        if suffix_ids:
            if suffix_attn_masks:
                # quick thing, have not tested
                token_ids = [torch.cat((objects[idx]._token_ids, suffix_ids[idx][suffix_attn_masks[idx] != 0])) for idx in range(batchsize)]
            else:
                token_ids = [torch.cat((objects[idx]._token_ids, suffix_ids[idx])) for idx in range(batchsize)]
        if suffices:
            texts = [
                str(objects[idx]) + suffices[idx]
                for idx in range(batchsize)
            ]
            if suffix_ids:
                for idx in range(batchsize):
                    assert objects[idx].tokenizer.decode(token_ids[idx]) == texts[idx]
            else:
                token_ids = objects[0].tokenizer(texts, padding = False, return_tensors='pt').input_ids
                assert token_ids.shape[0] == batchsize
                suffix_ids_maxlen = max([len(token_ids[idx]) - len(objects[idx]._token_ids) for idx in range(batchsize)])
                suffix_ids = torch.empty((batchsize, suffix_ids_maxlen), dtype=int)
                suffix_attn_masks = torch.ones_like(suffix_ids)
                for idx in range(batchsize):
                    suffix_ids_length = len(token_ids[idx]) - len(objects[idx]._token_ids)
                    suffix_ids[idx][-suffix_ids_length:] = token_ids[idx][-suffix_ids_length:]
                    suffix_attn_masks[idx][:-suffix_ids_length] = 0
        else: # suffices
            texts = [
                objects[idx].tokenizer.decode(token_ids[idx])
                for idx in range(batchsize)
            ]
        past_key_values = tuple([
            (
                torch.cat(
                    [
                        objects[idx]._past_key_values[layer][0]
                        for idx in range(batchsize)
                    ],
                    dim=0),
                torch.cat(
                    [
                        objects[idx]._past_key_values[layer][1]
                        for idx in range(batchsize)
                    ],
                    dim=0)
            )
            for layer in range(len(objects[0]._past_key_values))
        ])
        objects[0].model.eval()
        outputs = objects[0].model(input_ids=suffix_ids, past_key_values=past_key_values, attention_mask=suffix_attn_masks, use_cache=True, **kwparams)
        return [
            Generated(text = texts, token_ids = token_ids, outputs = outputs, batch_idx = idx, model = objects[idx].model, tokenizer = objects[idx].tokenizer, **kwparams)
            for idx in range(batchsize)
        ]
    @staticmethod
    def next_objs_probs(objs, **kwparams):
        batchsize = len(objs)
        id_probs_its = [iter(obj.next_id_probs()) for obj in objs]
        for item in range(len(objs[0]._logits)):
            # the first item in each list may be the one to parallelize
            next_id_probs = [next(it) for it in id_probs_its]
            suffix_ids = torch.stack([id[...,None] for id, prob in next_id_probs])
            token_ids = [torch.cat((objs[idx]._token_ids, suffix_ids[idx]), dim=-1) for idx in range(batchsize)]
            suffices = [objs[idx].tokenizer.decode(token_ids[idx])[len(str(objs[idx])):] for idx in range(batchsize)]
            next_objs = Generated.parallel(objs, suffices, suffix_ids, **kwparams) # suffix_attn_masks
            yield [(next_objs[idx], next_id_probs[idx][1]) for idx in range(batchsize)] # yield obj, prob
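
# a rough usage sketch for the class above; the prompt is just an example.
# Generated pairs the text with the cached past_key_values, so calling an
# instance with a suffix continues from the cache instead of re-running the
# whole prompt.
def _generated_demo():
    with torch.no_grad():
        gen = Generated('Things you might find in a magic closet:')
        # peek at the single most likely next token and its probability
        next_id, next_prob = next(iter(gen.next_id_probs()))
        print(round(float(next_prob) * 100, 2), '%', gen.tokenizer.decode([int(next_id)]))
        # continue from the cached key/values with an arbitrary suffix
        longer = gen(' a broom')
        print(str(longer))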


class Multiple(Generated):
    def __init__(self, text, eos_text, *extra_eos_texts, params = [], **kwparams):
        super().__init__(text, *params, **kwparams)
        self.eos_texts = [eos_text, *extra_eos_texts]
    def __iter__(self):
        #queue = [(1, 1, None, self, iter(Generated.next_objs_probs([self])))]
        queue = [(1, 1, None, self, iter(self.next_obj_probs()))]
        #prob_scale = 1
        # thinking of the control flow if the whole queue were done in parallel:
        # at the moment, it compares each text, and then continues the ones that are not complete.
        #   for continuing, it could continue every text, then recreate the queue with a single sort pass over the doubled contents, truncate it, and reverse it.
        #   when doing that, the pass to continue every text would bump into completed texts.
        while len(queue):
            max_prob, base_prob, prev_base, base, it = queue.pop()
            base_text = str(base)
            eos_found = False
            for eos_text in self.eos_texts:
                if base_text[len(str(self)):].endswith(eos_text):
                    generated = base_text[len(str(self)):-len(eos_text)].strip()
                    if len(generated):
                        yield generated, base_prob
                    #else:
                    #    prob_scale *= ... would need to track cumulative prob to do accurately
                    eos_found = True
                    break
            if eos_found:
                continue
            if max_prob == base_prob:
                print(len(queue), round(float(base_prob*100),4), str(base).strip()[:80], end='\r', flush=True)

            #next_base, next_prob = next(it)[0]
            next_base, next_prob = next(it)
            next_prob = next_prob * base_prob
            new_items = [
                #(next_prob, next_prob, base, next_base, iter(Generated.next_objs_probs([next_base]))),
                (next_prob, next_prob, base, next_base, iter(next_base.next_obj_probs())),
                (next_prob, base_prob, prev_base, base, it) # try this one again when prob drops
            ]
            for item in new_items:
                if len(queue) < 2048 and item[0] > 0.0001:
                    queue.append(item)
                elif queue[0][0] < item[0]:
                    queue[0] = item
                queue.sort(key = lambda tup: tup[:2])

if __name__ == '__main__':
    print("This example is to the point where it is outpacing damaged parts of Karl's cognition that he needs to heal.")
    print("Is there any possibility to add training wheels to it, so the user learns a little relevance?")
    # thinking the user could learn to provide 2 missing concepts when prompted with 1 example
    # ideally it would keep growing so it is user-created content rather than ai-created
    try:
        with torch.no_grad():
            #multiple = Multiple('Places you might find in a city:', '</s>', ',')#, '.')
            multiple = Multiple('Things you might find in a magic closet:', '</s>', ',')#, '.')
            #multiple = Multiple('Some striking environments for a fantasy, sci-fi, or historical scene:', '</s>', ',')#, '.')
            #multiple = Multiple('Names for alien home planets:', '</s>', ',') # these ones don't work and would need examples and/or filtering
            #elems = 'Dizquakin, The Wandering Rift, Vrazz, Mortudin, New Star Haven, Fradinabble, The Galactic Bridge, Wuspercin, The Twince Nebula, Tozquarn, Charred Call, Cialla, The Phoenix Moon, Birnat, The Gaping Maw, Cruftspace, Musprit'.split(', ') # not sure about crafting all these examples atm
            #import random
            #random.shuffle(elems)
            #multiple = Multiple(' dizquakin, the wandering rift, vrazz, mortudin, new star haven, fradinabble, the galactic bridge, wuspercin, the twince nebula, tozquarn, charred call, cialla, birnat, the gaping maw, cruftspace, musprit,', '</s>', ',')
            #multiple = Multiple(' ' + ', '.join(elems) + ',', '</s>', ',', '.')
            remaining = 1
            for completion, prob in iter(multiple):
                remaining -= float(prob)
                print(f'{round(float(remaining)*100,4)} + {round(float(prob)*100,4)}: {completion} {"  " * len(str(multiple)[:80])}')
    finally:
        print("This example is to the point where it is outpacing damaged parts of Karl's cognition that he needs to heal.")
        print("Is there any possibility to add training wheels to it, so the user learns a little relevance, and must learn to be able to do what the script does?")