Attached is code that trains a model that itself can produce a trained
model in roughly the same time as training the produced model manually
(down to randomness of initialization).

I wanted to keep working on it to make something useful, but it's
gotten too tense for me to be near, so I am stopping. This is why it
is messy.

I try to do this every now and then, from year to year. I'm sure many
have already completed it. This is the farthest I've gotten before
abandoning it!

The basis of the theory is that a transformer performs as much
computation as the length of its input and output sequences allows, so
you can increase its power by increasing the input and output size
while keeping the model itself smaller. This lets you work with the
entirety of a large model using a smaller one, for example.
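
For a rough sense of the tradeoff, here is a tiny standalone sketch
(illustrative configuration, not part of the attached code): parameter
count grows roughly with the square of d_model, while attention compute
grows with the square of sequence length times d_model, so a narrow
model fed longer sequences spends comparable compute on far fewer
weights.

import torch
small = torch.nn.Transformer(d_model=64, nhead=1, dim_feedforward=256, batch_first=True)
large = torch.nn.Transformer(d_model=512, nhead=1, dim_feedforward=2048, batch_first=True)
n_params = lambda m: sum(p.numel() for p in m.parameters())
# the parameter counts differ by well over an order of magnitude
print(n_params(small), n_params(large))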

A novelty I like in this approach is how I labeled the training
weights by providing additional dimensions of data directly on the
inputs ("encoded" in the source), rather than using position
embeddings. I'm not sure what this is called, but it seems possibly
more effective than position embeddings to me, since the first linear
encoder can learn arbitrary encoding forms yet only needs to be as
wide as the model dimension.
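
A minimal sketch of that labeling idea (standalone, separate from the
attached code): append the position index as one more input column and
let the first linear layer learn whatever encoding of it is useful.

import torch
seq_len, d_data, d_model = 16, 2, 8
data = torch.rand(seq_len, d_data)
pos = torch.arange(seq_len, dtype=data.dtype)[:, None]
labeled = torch.cat([data, pos], dim=-1)    # [seq_len, d_data + 1]
encoder = torch.nn.Linear(d_data + 1, d_model)
tokens = encoder(labeled)                   # [seq_len, d_model], no position embedding added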

Crazy Karl
import torch, torch.nn
import tqdm

class Transformer(torch.nn.Module):
    def __init__(self, d_model = 512, d_ff = 2048, n_layers = 6, d_input = None, d_output = None, is_causal = False, dtype=float, device='cpu'):
        super().__init__()
        # optional linear projections between the data width and the model width
        self.encoding = d_input and torch.nn.Linear(d_input, d_model, dtype=dtype, device=device)
        self.decoding = d_output and torch.nn.Linear(d_model, d_output, dtype=dtype, device=device)
        if n_layers > 1:
            # recoding folds the previous pass's output back in with the input,
            # letting the single decoder layer be applied n_layers times
            self.recoding = torch.nn.Linear(d_model*2, d_model, dtype=dtype, device=device)
        self.n_layers = n_layers
        self.d_input = d_input or d_model
        self.d_output = d_output or d_model
        self.is_causal = is_causal
            
        self.transformer = torch.nn.Transformer(
            d_model = d_model, 
            nhead = 1,
            num_decoder_layers = 1,#n_layers,
            dim_feedforward = d_ff,
            dropout = 0,#.1,
            activation = torch.nn.functional.silu,
            custom_encoder = lambda input, *params, **kwparams: input,
            layer_norm_eps = 1e-5,
            batch_first = True,
            norm_first = True, # some interest in removing norm
            bias = True,
            device = device,
            dtype = dtype,
        )

        # total scalar weight count, the widest parameter rank, and the shape
        # of the encoded representation: one row per scalar weight, with
        # columns for value, parameter index, and per-dimension coordinates
        self.raw_size = sum(p.numel() for p in self.parameters())
        self.raw_dims = max(len(p.shape) for p in self.parameters())
        self.encoded_size = [self.raw_size, 2 + self.raw_dims]
        #self.raw_names = list({n for n, p in self.named_parameters()})
        #self.raw_names.sort()
        #self.raw_names = { name: idx for idx, name in enumerate(self.raw_names) }
        #self.raw_depth = max([n.count('.')+1 for n, p in self.named_parameters()])
        ##self.raw_count = sum([1 for p in self.parameters()])
        #self.raw_names = list({name for n, p in self.named_parameters() for name in n.split('.')})
        #self.raw_names.sort()
        #self.raw_types = list({str(type(m)) for m in self.modules()})
        #self.raw_types.sort()
    def forward(self, inp):
        if self.encoding:
            inp = self.encoding(inp)
        out = self.transformer(
            # the custom encoder is an identity over an empty src; all
            # computation happens in the single decoder layer over tgt
            src=torch.empty(list(inp.shape[:-2])+[0,inp.shape[-1]], dtype=inp.dtype, device=inp.device),
            tgt=inp,
            tgt_is_causal=self.is_causal,
        )
        if self.n_layers > 1:
            # collect each pass's output along a new dim 1 so the trainer can
            # apply a loss to every layer's output
            out = out[:,None,...]
            for layer in range(self.n_layers - 1):
                out = torch.cat([
                    out,
                    self.transformer(
                        #src=self.recoding(out[:,-1,...]),
                        #tgt=inp,
                        src=torch.empty(list(inp.shape[:-2])+[0,inp.shape[-1]], dtype=inp.dtype, device=inp.device),
                        #tgt=torch.cat([self.recoding(out[:,-1,...]),inp],dim=-2),
                        #tgt=self.recoding(out[:,-1,...]),
                        tgt=self.recoding(torch.cat([inp,out[:,-1,...]], dim=-1)),
                        tgt_is_causal=self.is_causal,
                    )[:,None,-out.shape[-2]:,:]
                ], dim=1)
        if self.decoding:
            out = self.decoding(out)
        return out
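    # Shape note: for n_layers == 1 forward returns [batch, seq, d_output];
    # for n_layers > 1 it returns [batch, n_layers, seq, d_output], and the
    # trainer below applies a loss to each slice along dim 1.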
    CONST_DATA = 1
    CONST_ENCODED = 2
    #def forward_encoded(self, inp, model, encoded_0 = None):
    #    encoded_0 = encoded_0 or model.to_encoded()
    #    inputs_n = inputs.shape[0] * inputs.shape[1]
    #    context = torch.cat([
    #        torch.full([inputs_n, 1], CONST_DATA, device=inputs.device), # labels
    #        torch.arange(end=inputs.shape[0], device=inputs.device)[:,None,None].expand([inputs.shape[0], inputs.shape[1], 1]).reshape(inputs_n, 1), # batch ids
    #        inputs.view([inputs_n, inputs.shape[-1]]), # data
    #        torch.zeros([inputs_n, t_0.shape[-1] - inputs.shape[-1] - 1], device=inputs.device), # padding
    #    ])
    #    out = self(inp)
    #    model.from_encoded(out)
    def from_raw(self, raw):
        idx = 0
        with torch.no_grad():
            for p in self.parameters():
                p = p.view(-1)
                size = p.shape[-1]
                p[:] = raw[idx:idx+size]
                idx += size
    def to_raw(self, out=None):
        idx = 0
        for p in self.parameters():
            if out is None:
                out = torch.empty([self.raw_size], dtype=p.dtype, device=p.device)
            p = p.view(-1)
            size = p.shape[-1]
            out[idx:idx+size] = p
            idx += size
        return out
    def to_encoded(self, out=None):
        idx = 0
        ct = 0
        names = {}
        types = {}
        for name, param in self.named_parameters():
            if out is None:
                #out = torch.empty([self.raw_size, 1 + self.raw_dims], dtype=p.dtype, device=p.device)
                out = torch.empty(self.encoded_size, dtype=param.dtype, device=param.device)
            flattened = param.view(-1)
            size = flattened.shape[0]
            dims = len(param.shape)
            #dims = list(param.shape)
            out[idx:idx+size,0] = flattened
            out[idx:idx+size,1] = ct
            # might need to swizzle this one
            out[idx:idx+size,2:-dims] = 0
            out[idx:idx+size,-dims:] = torch.stack(torch.meshgrid([torch.arange(dim) for dim in param.shape], indexing='ij'), dim=-1).view(-1,dims)
            ct += 1
            idx += size
        return out
    def from_encoded(self, encoded):
        return self.from_raw(encoded[:,0])
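    # Layout sketch of the encoded tensor built above (one row per scalar weight):
    #   column 0            -> the weight's value
    #   column 1            -> index of the parameter tensor it belongs to
    #   middle columns      -> zero padding up to the widest parameter rank
    #   last `dims` columns -> the weight's coordinates within its tensor
    # Round trip: t.from_encoded(t.to_encoded()) restores the same weights.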

    def train_data(self, trainer, inputs, outputs, context = None, accuracy = 0.5):
        return trainer.train(model=self, inputs=inputs, outputs=outputs, context=context, accuracy=accuracy)

    #def train_encoded(self, trainer, model, inputs, outputs, context = None, accuracy = 0.5):
    #    return trainer.train(model


# - make it single-pass (so the loss is of the output's forward)
# - make it save state

# if we abstract the training approach into a class, it can then be passed to a function that trains on e.g. data or on results from encoded weights
# the function can then be first class, with the training approach as a parameter, which might simplify some of the work around the concept of interest
#### cognition maybe simpler if train_rect put into Transformer class?
class RectTrainer:
    # calling it rectangular training when all data trained equally
        # i think it would be faster to not do it fully rectangular
    def __init__(self, optim = torch.optim.SGD, lr=1e-3, loss_fn = torch.nn.functional.mse_loss, layer_loss_fn = lambda loss, idx, total: loss ** idx):
        # layer_loss_fn weights each layer's loss before summing; the default
        # loss**idx emphasizes later layers (and reduces layer 0 to a constant)
        self.optim = optim
        #self.optim_kwparams = optim_kwparams
        self.lr = lr
        self.loss_fn = loss_fn
        self.layer_loss_fn = layer_loss_fn
    def train(self, model, inputs, outputs, context = None, accuracy = 0.5):
        lr = self.lr
        optim = self.optim(model.parameters(), lr=lr)#**self.optim_kwparams)
        last_ls = None
        # we might add acceleration once something basic works. oh!
        if context is not None:
            ctx_inputs = torch.cat([context,inputs],dim=-2)
            output_offset = context.shape[-2]
        else:
            ctx_inputs = inputs
            output_offset = 0
        with tqdm.tqdm(total=0.5, unit='ls') as pbar:
            while (last_ls is None or last_ls > accuracy) and lr:
                attempt = model(ctx_inputs)
                if len(attempt.shape) <= 3:
                    ls = self.loss_fn(attempt[:,output_offset:,:], outputs)
                    ls_item = ls.item()
                else:
                    # earlier layers are included in the loss so that depth can be changed to
                    # exchange computation time for output accuracy.
                    ls = [
                        self.loss_fn(attempt[:,idx,output_offset:,:], outputs)
                        for idx in range(attempt.shape[1])
                    ]
                    ls_item = ls[-1].item()
                    ls = torch.stack([self.layer_loss_fn(ls[idx], idx, attempt.shape[1]) for idx in range(attempt.shape[1])]).sum()
                if last_ls is not None and last_ls < ls_item: #last_ls - ls < ls * lr: # this is likely wrong, maybe change to last_ls < ls
                    lr /= 2
                    optim.param_groups[0]['lr'] = lr
                    pbar.display()
                else:
                    ls.backward()
                    optim.step()
                    optim.zero_grad()
                    last_ls = ls_item
                pbar.desc = f'lr={lr}, acc={last_ls}'
                if ls_item + accuracy > pbar.total:
                    pbar.total = ls_item + accuracy
                pbar.update(pbar.total - ls_item + accuracy - pbar.n)
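
# Minimal usage sketch (mirrors the __main__ block below; numbers illustrative):
#   trainer = RectTrainer(lr=1e-4)
#   model = Transformer(d_model=8, d_ff=16, n_layers=2, d_input=2, d_output=1, dtype=torch.float32)
#   model.train_data(trainer, inputs, outputs, accuracy=0.1)  # runs until the final-layer loss <= accuracy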

#class TransformerTransformer(Transformer):
    #def __init__(self, d_model = 512, d_ff = 2048, n_layers = 6, d_input = None, d_output = None, dtype=float, device='cpu'):

if __name__ == '__main__':

    trainer = RectTrainer(lr=0.0001)
    # t is the model to be produced; t2 will learn to map t's initial encoded
    # weights (with the training data as context) to t's trained raw weights
    t = Transformer(d_model=8, d_ff=16, n_layers=2, d_input=2, d_output=1, dtype=torch.float32, device=0)
    t_0 = t.to_encoded()
    t2 = Transformer(d_model=8, d_ff=16, n_layers=2, d_input=1+t.encoded_size[-1], d_output=1, dtype=torch.float32, device=0)

    # toy task: 3 sequences of 16 two-feature inputs; the second feature is
    # overwritten with the position index (the labeling trick described above)
    inputs = torch.rand([3,16,2],device=0)
    inputs[:,:,1] = torch.arange(end=inputs.shape[-2], device=0)
    outputs = torch.rand([3,16,1],device=0)
    t.train_data(trainer, inputs, outputs, accuracy=0.1)
    #train_rect(
    #    model = t,
    #    inputs = inputs,
    #    outputs = outputs,
    #    accuracy = 0.1,
    #    lr = 0.0001,
    #) # move to next step. it works well enough. put one inside another.
    t_f = t.to_raw()



    CONST_DATA = 1
    CONST_ENCODED = 2

    # inputs are a sequence of vectors of properties, many of these sequences collected into a batch
    # dims are seq_idx, prop_idx
    inputs_n = inputs.shape[0] * inputs.shape[1]
    inputs2_data = torch.cat([
        torch.full([inputs_n, 1], CONST_DATA, device=inputs.device), # labels
        torch.arange(end=inputs.shape[0], device=inputs.device)[:,None,None].expand([inputs.shape[0], inputs.shape[1], 1]).reshape(inputs_n, 1), # batch ids
        inputs.view([inputs_n, inputs.shape[-1]]), # data
        torch.zeros([inputs_n, t_0.shape[-1] - inputs.shape[-1] - 1], device=inputs.device), # padding
    ], dim=-1)[None,...]
    inputs2_encoded = torch.cat([
        torch.full([t_0.shape[0], 1], CONST_ENCODED, device=inputs.device), # labels
        t_0.detach(), # data
    ], dim=-1)[None,...]
    outputs2 = t_f.detach()[None,:,None]
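    # Shape check for this configuration (all of t's parameters are vectors or
    # matrices, so raw_dims == 2 and each encoded row has 4 columns):
    #   inputs2_data    is [1, 48, 5] -> label, batch id, 2 data cols, 1 pad col
    #   inputs2_encoded is [1, raw_size, 5] -> label plus the 4 encoded columns
    #   outputs2        is [1, raw_size, 1] -> t's trained weights, one per row
    # which matches t2's d_input=5 and d_output=1.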
    t2.train_data(trainer, context=inputs2_data, inputs=inputs2_encoded, outputs=outputs2, accuracy=0.1)
    #train_rect(
    #    model = t2,
    #    context = inputs2_data,
    #    inputs = inputs2_encoded,
    #    outputs = outputs2,
    #    accuracy = 0.1,
    #    lr=0.0001,
    #)
    # next we want to try doing away with the first training step and train the second model directly on its data.
    # that will be in a new file.
    # this could be important because the loss is bound directly to the output. it may help ensure an approach is stabilized.