# Source code for mightypy.nlp.llm

"""
LLM
----
"""

import torch
from torch import nn
from tokkit import PyBytePairTokenizer
from tqdm import tqdm

_PYTORCH_DTYPE = torch.float32

class Word2Vec(nn.Module):
    """Embedding table followed by a linear projection back to vocabulary size.

    Args:
        vocab_size: Number of rows of the embedding table and width of the
            output logits.
        embedding_dims: Width of each embedding vector.
        device: Device on which both layers are created.
    """

    def __init__(self, vocab_size, embedding_dims, device="cpu"):
        super().__init__()
        # Both layers share the same placement and dtype.
        layer_kwargs = {"device": device, "dtype": _PYTORCH_DTYPE}
        self.embedding = nn.Embedding(vocab_size, embedding_dims, **layer_kwargs)  # V x E
        self.linear = nn.Linear(embedding_dims, vocab_size, **layer_kwargs)  # E x V

    def forward(self, X: torch.Tensor):
        """Embed the indices in ``X`` and project the embeddings to vocabulary logits."""
        return self.linear(self.embedding(X))
class BatchMultiHeadAttention(nn.Module):
    """This is more truthful to the paper.

    Multi-head self-attention ("Attention Is All You Need", section 3.2.2)
    that keeps one explicit weight tensor per projection, with a leading
    head dimension, instead of fused ``nn.Linear`` layers.

    Args:
        n_heads: Number of attention heads (H).
        d_model: Width of the input and output features (M).
        d_value: Per-head width of the value projection (V).
        d_key: Per-head width of the query/key projections (K).
        masked: When true, a position may only attend to itself and to
            earlier positions.
        dropout_p: Dropout probability applied to the output projection.
        device: Device on which the parameters are created.
    """

    def __init__(
        self, n_heads, d_model, d_value, d_key, masked=True, dropout_p=0.2, device="cpu"
    ):
        super().__init__()
        self._n_heads = n_heads
        self._d_model = d_model
        self._d_v = d_value
        self._d_k = d_key
        self._masked = masked
        self._device = device

        # TODO: this can be optimized with nn.Linear layers (see the V2 class).
        # Creation order (q, k, v, o) is kept so seeded runs stay reproducible.
        self.w_q = self._init_weight(n_heads, d_key, d_model)  # (H, K, M)
        self.w_k = self._init_weight(n_heads, d_key, d_model)  # (H, K, M)
        self.w_v = self._init_weight(n_heads, d_value, d_model)  # (H, V, M)
        self.w_o = self._init_weight(n_heads * d_value, d_model)  # (H*V, M)
        self.resid_dropout = nn.Dropout(dropout_p)

    def _init_weight(self, *shape):
        """Return a trainable parameter drawn uniformly from [0, 0.1)."""
        return nn.Parameter(
            torch.rand(*shape, device=self._device, dtype=_PYTORCH_DTYPE) * 1e-1
        )

    def forward(self, X: torch.Tensor):
        """Apply multi-head self-attention to ``X`` of shape (B, T, M); returns (B, T, M)."""
        _, T, _ = X.size()
        X = X.unsqueeze(1)  # (B, T, M) -> (B, 1, T, M)

        # (B, 1, T, M) x (H, M, K) = (B, H, T, K); the head axis broadcasts.
        q = X @ self.w_q.transpose(-2, -1)  # (B, H, T, K)
        k = X @ self.w_k.transpose(-2, -1)  # (B, H, T, K)
        v = X @ self.w_v.transpose(-2, -1)  # (B, H, T, V)

        # Scaled dot product: (B, H, T, K) x (B, H, K, T) = (B, H, T, T).
        # A Python scalar keeps the result on X's device and dtype.
        scores = (q @ k.transpose(-2, -1)) / (self._d_k ** 0.5)

        if self._masked:
            # Build the mask on the input's device (not the constructor-time
            # one) so the module keeps working after ``.to(...)``.
            # It could be registered once and sliced as mask[:T, :T] instead.
            future = torch.triu(
                torch.ones(T, T, dtype=torch.bool, device=X.device), diagonal=1
            )
            scores = scores.masked_fill(future, float("-inf"))

        probs = torch.softmax(scores, dim=-1)  # (B, H, T, T)

        # (B, H, T, T) x (B, H, T, V) = (B, H, T, V) -> (B, T, H, V)
        attn_out = (probs @ v).permute(0, 2, 1, 3)
        B, T, H, V = attn_out.shape

        # Concatenate heads and project: (B, T, H*V) x (H*V, M) = (B, T, M)
        out_projection = attn_out.reshape(B, T, H * V) @ self.w_o
        return self.resid_dropout(out_projection)
class BatchMultiHeadAttentionV2(nn.Module):
    """This is modern implementation for more efficient processing.

    Multi-head self-attention ("Attention Is All You Need", section 3.2.2)
    in which all heads are produced by a single matrix multiplication per
    projection and the result is reshaped into heads.

    Args:
        n_heads: Number of attention heads (H).
        d_model: Width of the input and output features (M).
        d_value: Per-head width of the value projection (V).
        d_key: Per-head width of the query/key projections (K).
        masked: When true, a position may only attend to itself and to
            earlier positions.
        dropout_p: Dropout probability applied to the output projection.
        device: Device on which the layers are created.
    """

    def __init__(
        self, n_heads, d_model, d_value, d_key, masked=True, dropout_p=0.2, device="cpu"
    ):
        super().__init__()
        self._n_heads = n_heads
        self._d_model = d_model
        self._d_v = d_value
        self._d_k = d_key
        self._masked = masked
        self._device = device

        # All heads come out of one multiplication per projection and are
        # reshaped afterwards: faster, but less interpretable than V1.
        linear_kwargs = {"bias": False, "device": device, "dtype": _PYTORCH_DTYPE}
        self.w_q = nn.Linear(d_model, d_key * n_heads, **linear_kwargs)
        self.w_k = nn.Linear(d_model, d_key * n_heads, **linear_kwargs)
        self.w_v = nn.Linear(d_model, d_value * n_heads, **linear_kwargs)
        self.w_o = nn.Linear(n_heads * d_value, d_model, **linear_kwargs)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, X: torch.Tensor):
        """Apply multi-head self-attention to ``X`` of shape (B, T, M); returns (B, T, M)."""
        B, T, _ = X.size()

        # Ordering matters: split the fused feature axis into (H, K) first,
        # then move the head axis in front of time.
        # (B, T, M) -> (B, T, K*H) -> (B, T, H, K) -> (B, H, T, K)
        q = self.w_q(X).view(B, T, self._n_heads, self._d_k).transpose(1, 2)
        k = self.w_k(X).view(B, T, self._n_heads, self._d_k).transpose(1, 2)
        # (B, T, M) -> (B, T, V*H) -> (B, T, H, V) -> (B, H, T, V)
        v = self.w_v(X).view(B, T, self._n_heads, self._d_v).transpose(1, 2)

        # Scaled dot product: (B, H, T, K) x (B, H, K, T) = (B, H, T, T).
        # A Python scalar keeps the result on X's device and dtype.
        scores = (q @ k.transpose(-2, -1)) / (self._d_k ** 0.5)

        if self._masked:
            # Build the mask on the input's device (not the constructor-time
            # one) so the module keeps working after ``.to(...)``.
            # It could be registered once and sliced as mask[:T, :T] instead.
            future = torch.triu(
                torch.ones(T, T, dtype=torch.bool, device=X.device), diagonal=1
            )
            scores = scores.masked_fill(future, float("-inf"))

        # No forced dtype here: probabilities must match ``v`` for the matmul
        # below, otherwise a module cast with .double()/.half() fails.
        probs = torch.softmax(scores, dim=-1)  # (B, H, T, T)

        # (B, H, T, T) x (B, H, T, V) = (B, H, T, V) -> (B, T, H, V)
        attn_out = (probs @ v).permute(0, 2, 1, 3)
        B, T, H, V = attn_out.shape

        # Concatenate heads and project: (B, T, H*V) -> (B, T, M)
        out_projection = self.w_o(attn_out.reshape(B, T, H * V))
        return self.resid_dropout(out_projection)
class PositionalEmbedding(nn.Module):
    """Placeholder module: no parameters or behaviour are defined yet."""
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    The table is computed once at construction time and stored as a
    non-persistent buffer: it follows the module through ``.to(...)`` and
    dtype casts, but is not written to ``state_dict``.

    Args:
        d_model: Width of the embeddings (number of table columns).
        context_len: Number of positions (table rows) to precompute.
        scale: Base of the geometric progression of wavelengths.
        device: Device on which the table is created.
    """

    def __init__(self, d_model, context_len=10_000, scale=10_000, device="cpu"):
        super().__init__()
        self.scale = scale
        self._device = device
        self._context_len = context_len
        self._d_model = d_model
        self.register_buffer("pe", self._positional_encoding(), persistent=False)

    def _positional_encoding(self):
        """Build the (context_len, d_model) encoding table.

        * Rows - Positions (sentence length, number of tokens in input sentence)
        * Columns - Dimensions (Dimensions of embedding or models)
        """
        table = torch.zeros(
            (self._context_len, self._d_model), device=self._device, dtype=_PYTORCH_DTYPE
        )
        positions = torch.arange(
            self._context_len, dtype=_PYTORCH_DTYPE, device=self._device
        ).unsqueeze(1)  # (L, 1)
        exponents = (
            torch.arange(0, self._d_model, 2, device=self._device).unsqueeze(0)
            / self._d_model
        )  # (1, ceil(M / 2))
        angles = positions * (1 / torch.pow(self.scale, exponents))  # (L, ceil(M / 2))

        # Even columns take the sine, odd columns the cosine. With an odd
        # d_model there is one odd column fewer, hence the ``// 2`` slice.
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles[:, : self._d_model // 2])
        return table

    def forward(self, X):
        """Return ``X`` (B, T, M) with the first T rows / M columns of the table added.

        Raises:
            RuntimeError: If T exceeds ``context_len`` or M exceeds ``d_model``.
        """
        _, row, col = X.shape  # (B, T, M)
        # Fail with a clear message instead of a broadcasting error (or a
        # silent broadcast when the table has a single row or column).
        if row > self.pe.shape[0] or col > self.pe.shape[1]:
            raise RuntimeError(
                f"input of shape {tuple(X.shape)} exceeds the positional table "
                f"of shape {tuple(self.pe.shape)}"
            )
        return X + self.pe[:row, :col]
class FFN(nn.Module):
    """Two-layer feed-forward network with a 4x wider hidden layer.

    Args:
        in_units: Width of the input features.
        out_units: Width of the output features.
        dropout_p: Dropout probability applied to the output.
        device: Device on which the layers are created.
    """

    def __init__(self, in_units, out_units, dropout_p=0.2, device="cpu"):
        super().__init__()
        hidden_units = in_units * 4
        self.linear1 = nn.Linear(
            in_units, hidden_units, bias=True, device=device, dtype=_PYTORCH_DTYPE
        )
        self.relu = nn.ReLU().to(device)
        self.linear2 = nn.Linear(
            hidden_units, out_units, bias=True, device=device, dtype=_PYTORCH_DTYPE
        )
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, X):
        """Expand, apply ReLU, project to ``out_units`` and apply dropout."""
        # X = (B, T, M)
        hidden = self.relu(self.linear1(X))
        return self.dropout(self.linear2(hidden))
class RepeatBlock(nn.Module):
    """One repeated block: masked attention, unmasked attention, feed-forward.

    Each of the three sub-layers is preceded by its own layer normalisation
    and followed by a residual addition.

    Args:
        n_heads: Number of attention heads.
        d_model: Width of the features flowing through the block.
        d_key: Per-head width of the query/key projections.
        d_value: Per-head width of the value projection.
        device: Device on which the sub-layers are created.
        dropout_p: Dropout probability handed to the attention and
            feed-forward sub-layers.
    """

    def __init__(self, n_heads, d_model, d_key, d_value, device, dropout_p):
        super().__init__()
        self._n_heads = n_heads
        self._d_model = d_model
        self._d_key = d_key
        self._d_value = d_value
        self._device = device

        # The two attention layers differ only in their masking.
        attn_kwargs = {
            "n_heads": n_heads,
            "d_model": d_model,
            "d_value": d_value,
            "d_key": d_key,
            "device": device,
            "dropout_p": dropout_p,
        }
        self._masked_multi_head_attn = BatchMultiHeadAttentionV2(masked=True, **attn_kwargs)
        self._multi_head_attn = BatchMultiHeadAttentionV2(masked=False, **attn_kwargs)

        norm_kwargs = {"device": device, "dtype": _PYTORCH_DTYPE}
        self._layer_norm1 = nn.LayerNorm(d_model, **norm_kwargs)
        self._layer_norm2 = nn.LayerNorm(d_model, **norm_kwargs)
        self._layer_norm3 = nn.LayerNorm(d_model, **norm_kwargs)

        self._feed_forward = FFN(
            in_units=d_model, out_units=d_model, device=device, dropout_p=dropout_p
        )

    def forward(self, X):
        """Run the three sub-layers on ``X`` of shape (B, T, M); returns (B, T, M)."""
        # Deviation from "Attention Is All You Need": layer norm is applied
        # before each sub-layer instead of after it.
        # NOTE(review): each residual is taken from the *normalised* tensor,
        # not from the sub-layer's un-normalised input - confirm that this
        # is intended.
        normed = self._layer_norm1(X)  # (B, T, M)
        after_masked_attn = normed + self._masked_multi_head_attn(normed)  # (B, T, M)

        normed = self._layer_norm2(after_masked_attn)  # (B, T, M)
        after_attn = normed + self._multi_head_attn(normed)  # (B, T, M)

        normed = self._layer_norm3(after_attn)  # (B, T, M)
        return normed + self._feed_forward(normed)  # (B, T, M)
class LLM(nn.Module):
    """Stack of ``RepeatBlock`` layers that predicts logits for the next token.

    Args:
        n_heads: Number of attention heads in every block.
        d_model: Width of the input embeddings and of the hidden features.
        d_key: Per-head width of the query/key projections.
        d_value: Per-head width of the value projection.
        n_x: Number of repeated blocks.
        vocab_size: Width of the output logits.
        dropout_p: Dropout probability handed to every block.
        device: Device on which all sub-modules are created.
    """

    def __init__(self, n_heads, d_model, d_key, d_value, n_x, vocab_size, dropout_p, device="cpu"):
        super().__init__()
        self._device = device
        self._n_heads = n_heads
        self._d_model = d_model
        self._d_key = d_key
        self._d_value = d_value
        self._n_x = n_x
        self._vocab_size = vocab_size

        self._pe = PositionalEncoding(d_model=self._d_model, device=self._device)
        self._linear = torch.nn.Linear(
            self._d_model, self._vocab_size, bias=False, device=self._device, dtype=_PYTORCH_DTYPE
        )
        # Sub-modules held in a plain list are not registered with the
        # parent, so nn.ModuleList is required for the repeated blocks.
        self._repeat_blocks = nn.ModuleList(
            [
                RepeatBlock(
                    self._n_heads,
                    self._d_model,
                    self._d_key,
                    self._d_value,
                    device=self._device,
                    dropout_p=dropout_p,
                )
                for _ in range(n_x)
            ]
        )

    def forward(self, X: torch.Tensor):
        """Return next-token logits of shape (B, 1, V) for embeddings ``X`` of shape (B, T, M).

        Raw logits are returned; no softmax is applied.
        """
        # Call the sub-modules themselves (not ``.forward``) so that any
        # registered hooks run.
        X = self._pe(X)  # (B, T, M)
        for repeat_block in self._repeat_blocks:
            X = repeat_block(X)  # (B, T, M)
        # Only the last position is returned, so project just that position
        # instead of projecting all T positions and discarding T - 1 of them.
        return self._linear(X[:, [-1], :])  # (B, 1, M) -> (B, 1, V)

    @property
    def total_params(self) -> int:
        """Total number of elements across all parameters of the model."""
        return sum(param.numel() for param in self.parameters())
def train(data_loader, llm_model: LLM, emb_model, loss_fn, optimizer, epochs, device):
    """Train ``llm_model`` to predict the token that follows each context.

    Args:
        data_loader: Iterable yielding ``(X_idx, y_idx)`` batches; per the
            shape notes below these are presumably token-index tensors of
            shape (B, T) and (B, 1) - verify against the caller.
        llm_model: Model that maps embeddings (B, T, M) to logits (B, 1, C).
        emb_model: Object whose ``embedding`` attribute maps indices to
            embeddings.
        loss_fn: Callable taking ``(logits, targets)`` and returning a
            scalar loss tensor.
        optimizer: Optimizer that is stepped once per batch.
        epochs: Number of passes over ``data_loader``.
        device: Device the embeddings and targets are moved to.

    The sum of the batch losses is printed after every epoch.
    """
    for _ in range(epochs):
        running_loss = 0.0
        for X_idx, y_idx in tqdm(data_loader):
            # The lookup runs wherever X_idx and the embedding table live;
            # only its result is moved to the training device.
            input_embeddings = emb_model.embedding(X_idx).to(device)  # (B, T, M)
            # Targets must sit on the same device as the logits, otherwise
            # the loss raises a device mismatch when training off-CPU.
            targets = y_idx.to(device)  # (B, 1)

            # Call the module (not ``.forward``) so registered hooks run.
            pred_logits = llm_model(input_embeddings)  # (B, 1, C)
            # The class axis goes second, matching the (B, C, 1) / (B, 1)
            # pairing handed to the loss.
            pred_logits = pred_logits.permute(0, 2, 1)  # (B, C, 1)

            loss = loss_fn(pred_logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch Loss: {running_loss:.6f}")
@torch.no_grad()
def generate(
    llm_model: LLM,
    emb_model: Word2Vec,
    tokenizer: PyBytePairTokenizer,
    context,
    max_tokens,
    top_k,
    temperature,
    device="cpu",
):
    """Autoregressively extend ``context`` by ``max_tokens`` sampled tokens.

    Args:
        llm_model: Model producing next-token logits from embeddings.
        emb_model: Model whose ``embedding`` layer maps token ids to embeddings.
        tokenizer: Object providing ``encode`` and ``decode``.
        context: Prompt handed to ``tokenizer.encode``.
        max_tokens: Number of tokens to append.
        top_k: Number of highest-scoring tokens kept before sampling; values
            larger than the vocabulary are clamped to the vocabulary size.
        temperature: Divisor applied to the logits before the softmax;
            larger values flatten the distribution.
        device: Device on which token ids and embeddings are placed.

    Returns:
        The result of ``tokenizer.decode`` on the prompt ids followed by the
        sampled ids.
    """
    idxs = torch.tensor(tokenizer.encode(context), dtype=torch.int64).to(device)

    # Sample with dropout disabled, then restore whatever mode the model
    # was in so that an interleaved training loop is not affected.
    was_training = llm_model.training
    llm_model.eval()
    try:
        for _ in range(max_tokens):
            embeddings = emb_model.embedding(idxs).to(device)  # (T, M)
            # Call the module (not ``.forward``) so registered hooks run.
            logits = llm_model(embeddings.unsqueeze(0))  # (1, T, M) -> (1, 1, V)
            # The model is autoregressive: keep the last position's logits.
            final_tokens_logits = logits[:, -1, :]  # (1, V)

            # Top-k filtering: everything below the k-th largest logit is
            # excluded. torch.topk rejects k larger than the vocabulary,
            # so clamp it.
            k = min(top_k, final_tokens_logits.size(-1))
            top_values, _ = torch.topk(final_tokens_logits, k, dim=-1)
            least_values = top_values[:, [-1]]  # (1, 1)
            final_tokens_logits = final_tokens_logits.masked_fill(
                final_tokens_logits < least_values, float("-inf")
            )  # (1, V)

            # Temperature scaling: a higher temperature scales the logits
            # down and makes the output more random.
            final_probs = torch.softmax(final_tokens_logits / temperature, dim=-1)  # (1, V)
            next_token = torch.multinomial(final_probs, num_samples=1)  # (1, 1)

            # Append the sampled token to the running sequence.
            idxs = torch.cat([idxs, next_token.view(-1)], dim=0)
    finally:
        llm_model.train(was_training)

    return tokenizer.decode(idxs.cpu().tolist())
if __name__ == "__main__":
    # Smoke test: push one random batch through both attention
    # implementations and print the resulting shapes.
    B, T, M = 32, 10, 100
    X = torch.rand(B, T, M)
    print(X.shape)

    # (class, per-head key/value width) for each implementation under test.
    for attention_cls, head_dim in (
        (BatchMultiHeadAttention, 6),
        (BatchMultiHeadAttentionV2, 10),
    ):
        attention = attention_cls(6, M, head_dim, head_dim)
        print(attention.forward(X).shape)