[ot][spam][crazy]

Undescribed Horrific Abuse, One Victim & Survivor of Many gmkarl at gmail.com
Tue Oct 3 16:53:16 PDT 2023


i used HuggingFace's LlamaModel, which is just the bare Llama architecture
without the language-modeling head.

i ignored the model's token embedding table and instead passed my own
embeddings, which i generated with a trainable linear module from the
input model weights and data.

similarly, i used a trainable linear layer on the output to produce a
single float per position, and used the model in a causal manner (you
can train on entire sequences, then infer 1 float at a time).

I've trimmed the code below for conciseness, so it may have an
inconsistency if i made a trimming mistake. a short usage sketch
follows the class.

import os
import torch, transformers

class make_one_transformer(torch.nn.Module):
    def __init__(self, name, input_size, output_size=1,
                 complexity=None, load=True):
        super().__init__()
        self.name = name
        self.input_size = input_size
        self.output_size = output_size
        if complexity is None:
            complexity = max(output_size, input_size*16)

        # ratios taken from the default LlamaConfig
        layers = max(complexity // 1024, 1)
        # round hidden_size down to a multiple of 2*layers so it splits
        # evenly across the attention heads with an even per-head dimension
        hidden_size = complexity // 8 // (2*layers) * 2*layers
        intermediate_size = int(complexity // 2.9)
        self.config = transformers.LlamaConfig(
            num_attention_heads=layers,
            num_hidden_layers=layers,
            num_key_value_heads=layers,
            vocab_size=output_size,
            max_position_embeddings=input_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.model = transformers.LlamaModel(self.config)

        # maps each raw input float to a hidden_size embedding vector
        self.embeddings = torch.nn.Linear(in_features=1,
                                          out_features=self.config.hidden_size)

        # projects each hidden state down to output_size floats
        self.output_head = torch.nn.Linear(self.config.hidden_size,
                                           self.output_size, bias=False)

        if load and os.path.exists(f'{name}.pt'):
            state_dict = torch.load(f'{name}.pt')
            self.iteration = state_dict.pop('iteration')
            self.load_state_dict(state_dict)
        else:
            self.iteration = 0

    def forward(self, input):
        # linear layer maps each input float to the hidden size
        inputs_embeds = self.embeddings(input[...,None])
        output = self.model(inputs_embeds=inputs_embeds).last_hidden_state
        return self.output_head(output)

    def generate(self, input, length):
        # not totally sure about what past key vals needs, but it looks
        # like you could pass it straight from outputs and debug
        for idx in range(length):
            inputs_embeds = self.embeddings(input[...,None])
            hidden = self.model(inputs_embeds=inputs_embeds).last_hidden_state
            # since we do have an output size, we'll want an lm_head
            output = self.output_head(hidden)
            # append only the newest predicted float to the sequence
            input = torch.cat([input, output[..., -1, :]], dim=-1)
        return input[...,-length:]

    # this model no lm_head !
    # the above joke, retained for humor, was made before output_head was added
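

here's a rough usage sketch of the pattern described above: train on
entire sequences with a next-float target, then generate causally. the
sine-wave task, the name, and the hyperparameters are invented for
illustration only.

import torch

seq_len = 64
model = make_one_transformer('sine_test', input_size=seq_len)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for step in range(100):
    # batch of sine waves with random phases, shape (batch, seq_len)
    phase = torch.rand(8, 1) * 6.28
    data = torch.sin(phase + torch.linspace(0, 6.28, seq_len)[None, :])
    preds = model(data)[..., 0]  # one float predicted per position
    # causal objective: each position predicts the following float
    loss = torch.nn.functional.mse_loss(preds[:, :-1], data[:, 1:])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    # prime with the first 16 floats and infer 16 more, one at a time
    continuation = model.generate(data[:1, :16], length=16)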


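and re the past key values comment in generate(): a rough, untested
sketch of how the cache might be threaded through so each step only
processes the newest float, assuming the use_cache / past_key_values
interface of recent transformers versions.

def generate_cached(self, input, length):
    # hypothetical variant of generate() that reuses the key/value cache;
    # could be attached as a method of make_one_transformer
    past = None
    inputs_embeds = self.embeddings(input[..., None])
    for idx in range(length):
        outputs = self.model(inputs_embeds=inputs_embeds,
                             past_key_values=past, use_cache=True)
        past = outputs.past_key_values
        # only the last position's hidden state matters for the new float
        new = self.output_head(outputs.last_hidden_state[..., -1:, :])
        input = torch.cat([input, new[..., 0]], dim=-1)
        # the next pass only needs to embed the newly generated float
        inputs_embeds = self.embeddings(new)
    return input[..., -length:]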