I'm working on the matrix permutations inside the Hugging Face Perceiver and the main PyTorch implementation of efficient attention. I have two scripts that call them so I can step through the code and map the offsets; I find the axis permutations very hard to reason about.
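
For reference while stepping through, here is a minimal sketch of the kind of axis permutation these attention implementations perform when splitting channels into heads. The shapes and variable names are illustrative only, not taken from either codebase:

import torch

batch, seq, channels, heads = 2, 96, 256, 8   # illustrative sizes
head_dim = channels // heads
x = torch.randn(batch, seq, channels)

# [batch, seq, channels] -> [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
x_heads = x.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)

# attention scores then live in [batch, heads, q_seq, k_seq]
scores = x_heads @ x_heads.transpose(-2, -1)

# the inverse permutation restores [batch, seq, channels]
x_back = x_heads.permute(0, 2, 1, 3).reshape(batch, seq, channels)
assert x_back.shape == x.shape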

Scripts are attached; both are inlined below.

perceiver_loader.py (the first listing below) also functions as an interactive loader for the model generated by the command line in the previous email. Running it against a trained model shows that training fits many thousands of numbers but still fails on rare ones, especially small numbers, which have only so many examples in the data. It also fails when the input format changes, for example when the word 'and' or a hyphen is added.
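For example, the following all denote the same value, but only the first matches what number_to_word in the script produces; the variants are just my illustration of the kind of format change that trips it up:

    one hundred twenty three
    one hundred and twenty three
    one-hundred twenty-three
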
try:
    import patch_pytorch
except ImportError:
    print('failed to find patch_pytorch, may crash on raspberry pi. github/xloem/mempickle')
import math
thousands_names = ' thousand million billion'.split(' ')
numeral_names = 'zero one two three four five six seven eight nine'.split(' ')
tens_names = 'zero ten twenty thirty forty fifty sixty seventy eighty ninety'.split(' ')
teens_names = 'ten eleven twelve thirteen fourteen fifteen sixteen seventeen eighteen nineteen'.split(' ')

# convert an integer to its English word form
def number_to_word(num):
    num = int(num)
    if num == 0:
        return 'zero'
    prefix = ''
    suffix = ''
    if num < 0:
        prefix += 'negative '
        num = -num
    places = int(math.log10(num)) + 1
    # walk the number three decimal digits at a time: ones, thousands, millions, ...
    for digit in range(0, places, 3):
        value = num % 1000
        num //= 1000
        if value == 0:
            continue
        hundred = value // 100
        ten = (value % 100) // 10
        one = value % 10
        part = ''
        if hundred > 0:
            part += numeral_names[hundred] + ' hundred'
        if ten == 1:
            if len(part):
                part += ' '
            part += teens_names[one]
        else:
            if ten > 0:
                if len(part):
                    part += ' '
                part += tens_names[ten]
            if one > 0:
                if len(part):
                    part += ' '
                part += numeral_names[one]
        if digit > 0 and len(part):
            part += ' ' + thousands_names[digit // 3]
        if len(suffix):
            part += ' '
        suffix = part + suffix
    return prefix + suffix
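
A quick sanity check of the helper; the expected outputs are traced by hand from the function above:

>>> number_to_word(0)
'zero'
>>> number_to_word(-40)
'negative forty'
>>> number_to_word(1234)
'one thousand two hundred thirty four'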


import transformers, torch

class Model(transformers.PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.input_preprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor(config)
        self.decoder = transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder(
            config,
            output_num_channels = config.d_latents,
            output_index_dims = config.max_position_embeddings,
            num_channels = config.d_model,
            qk_channels = config.qk_channels,
            v_channels = config.d_model,
            num_heads = config.num_decoder_heads,
            use_query_residual = False,
            final_project = False,
            trainable_position_encoding_kwargs = dict(
                num_channels = self.input_preprocessor.num_channels,
                index_dims = config.max_position_embeddings
            ),
        )
        self.perceiver = transformers.PerceiverModel(
            config,
            decoder = self.decoder,
            input_preprocessor = self.input_preprocessor,
        )
        self.output_postprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverEmbeddingDecoder(config)

        self.post_init()
    def forward(self, inputs=None, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, labels=None):#, return_dict=None, input_ids=None):
        outputs = self.perceiver(
                inputs=inputs,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=False,#return_dict,
        )

        logits = self.output_postprocessor(
                #outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
                outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
        )

        loss = None
        if labels is not None:
            loss = torch.nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1))
        
        output = (logits,) + outputs[1:] # outputs[2:]
        if loss is None:
            return output
        else:
            return ((loss,) + output)

config = transformers.PerceiverConfig()
config.num_decoder_heads = config.num_cross_attention_heads
config.num_self_attends_per_block = 3#6#13#26#6
config.max_position_embeddings = 96
config.d_model = 96#384#768#128
config.d_latents = 160#640#1280#256
config.vocab_size = 256
config.qk_channels = 256#8 * 32
config.v_channels = config.d_latents
print('Constructing model ...', flush=True)
model = Model.from_pretrained('words2nums')
config = model.config
#config.chunk_size_query = 16
#config.chunk_size_key = 16
## maybe: per-process vmem for low-end systems; https://github.com/xloem/mempickle
#import pytorch_tensormap
#mmap_params = pytorch_tensormap.PyTorchMap()
#mmap_params.write(model.state_dict())
#model.load_state_dict(mmap_params.read(writeable = True))

import torch

model.eval()
while True:
    word = input('input word number: ')
    # one byte per character; the 0x9c padding byte reads back as -100 when viewed as signed int8
    word_data = torch.frombuffer(bytearray(word, 'iso-8859-1').ljust(config.max_position_embeddings, b'\x9c'), dtype=torch.int8).to(torch.long).view((1, config.max_position_embeddings))
    # mask out the padding positions, then replace them with spaces for the embedding lookup
    word_mask = (word_data != -100).to(torch.float32)
    word_data[word_data == -100] = 32
    logits, output = model(inputs=word_data, attention_mask=word_mask)
    numbers = logits[0].detach().argmax(dim=1).to(torch.uint8).cpu().numpy().tobytes()
    # trim at the first '.', if any (find() returns -1 when it is absent)
    end = numbers.find(b'.')
    if end != -1:
        numbers = numbers[:end]
    print(numbers)
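
The second script is identical except that it enables the config.chunk_size_query and config.chunk_size_key settings that the first one leaves commented out: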

try:
    import patch_pytorch
except ImportError:
    print('failed to find patch_pytorch, may crash on raspberry pi. github/xloem/mempickle')
import math
thousands_names = ' thousand million billion'.split(' ')
numeral_names = 'zero one two three four five six seven eight nine'.split(' ')
tens_names = 'zero ten twenty thirty forty fifty sixty seventy eighty ninety'.split(' ')
teens_names = 'ten eleven twelve thirteen fourteen fifteen sixteen seventeen eighteen nineteen'.split(' ')

# convert an integer to its English word form
def number_to_word(num):
    num = int(num)
    if num == 0:
        return 'zero'
    prefix = ''
    suffix = ''
    if num < 0:
        prefix += 'negative '
        num = -num
    places = int(math.log10(num)) + 1
    # walk the number three decimal digits at a time: ones, thousands, millions, ...
    for digit in range(0, places, 3):
        value = num % 1000
        num //= 1000
        if value == 0:
            continue
        hundred = value // 100
        ten = (value % 100) // 10
        one = value % 10
        part = ''
        if hundred > 0:
            part += numeral_names[hundred] + ' hundred'
        if ten == 1:
            if len(part):
                part += ' '
            part += teens_names[one]
        else:
            if ten > 0:
                if len(part):
                    part += ' '
                part += tens_names[ten]
            if one > 0:
                if len(part):
                    part += ' '
                part += numeral_names[one]
        if digit > 0 and len(part):
            part += ' ' + thousands_names[digit // 3]
        if len(suffix):
            part += ' '
        suffix = part + suffix
    return prefix + suffix


import transformers, torch

class Model(transformers.PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.input_preprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor(config)
        self.decoder = transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder(
            config,
            output_num_channels = config.d_latents,
            output_index_dims = config.max_position_embeddings,
            num_channels = config.d_model,
            qk_channels = config.qk_channels,
            v_channels = config.d_model,
            num_heads = config.num_decoder_heads,
            use_query_residual = False,
            final_project = False,
            trainable_position_encoding_kwargs = dict(
                num_channels = self.input_preprocessor.num_channels,
                index_dims = config.max_position_embeddings
            ),
        )
        self.perceiver = transformers.PerceiverModel(
            config,
            decoder = self.decoder,
            input_preprocessor = self.input_preprocessor,
        )
        self.output_postprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverEmbeddingDecoder(config)

        self.post_init()
    def forward(self, inputs=None, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, labels=None):#, return_dict=None, input_ids=None):
        outputs = self.perceiver(
                inputs=inputs,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=False,#return_dict,
        )

        logits = self.output_postprocessor(
                #outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
                outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
        )

        loss = None
        if labels is not None:
            loss = torch.nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1))
        
        output = (logits,) + outputs[1:] # outputs[2:]
        if loss is None:
            return output
        else:
            return ((loss,) + output)

config = transformers.PerceiverConfig()
config.num_decoder_heads = config.num_cross_attention_heads
config.num_self_attends_per_block = 3#6#13#26#6
config.max_position_embeddings = 96
config.d_model = 96#384#768#128
config.d_latents = 160#640#1280#256
config.vocab_size = 256
config.qk_channels = 256#8 * 32
config.v_channels = config.d_latents
print('Constructing model ...', flush=True)
model = Model.from_pretrained('words2nums')
config = model.config
config.chunk_size_query = 16
config.chunk_size_key = 16
## maybe: per-process vmem for low-end systems; https://github.com/xloem/mempickle
#import pytorch_tensormap
#mmap_params = pytorch_tensormap.PyTorchMap()
#mmap_params.write(model.state_dict())
#model.load_state_dict(mmap_params.read(writeable = True))

import torch

model.eval()
while True:
    word = input('input word number: ')
    # one byte per character; the 0x9c padding byte reads back as -100 when viewed as signed int8
    word_data = torch.frombuffer(bytearray(word, 'iso-8859-1').ljust(config.max_position_embeddings, b'\x9c'), dtype=torch.int8).to(torch.long).view((1, config.max_position_embeddings))
    # mask out the padding positions, then replace them with spaces for the embedding lookup
    word_mask = (word_data != -100).to(torch.float32)
    word_data[word_data == -100] = 32
    logits, output = model(inputs=word_data, attention_mask=word_mask)
    numbers = logits[0].detach().argmax(dim=1).to(torch.uint8).cpu().numpy().tobytes()
    # trim at the first '.', if any (find() returns -1 when it is absent)
    end = numbers.find(b'.')
    if end != -1:
        numbers = numbers[:end]
    print(numbers)
