Source code for mightypy.nlp.dataset
"""
Dataset
--------
"""
import torch
from torch.utils.data import Dataset
from tokkit import data_loader
from mightypy.datautils.download import FileDownloader
[docs]
class CustomDataset(Dataset):
    """Next-token prediction dataset over a tokenized text corpus.

    The corpus at ``path`` is read with ``data_loader`` and encoded with
    ``tokenizer.encode_corpus``. If ``path`` is an ``http://``/``https://``
    URL it is first handed to ``FileDownloader.save_file_local``, whose
    return value is used as the local path.

    Sample ``idx`` is the pair ``(tokens[idx : idx + context_length],
    [tokens[idx + context_length]])``: a window of ``context_length`` token
    ids and the single id that follows it.

    Args:
        path: Local file path or URL of the raw corpus.
        tokenizer: Object exposing ``size`` and ``encode_corpus(raw_data)``.
        context_length: Number of input tokens per sample; must be >= 1.
        dataset_path: Passed to ``FileDownloader`` (presumably the download
            destination — confirm against ``mightypy.datautils.download``).
        device: Device the returned tensors are moved to.

    Raises:
        ValueError: If ``context_length`` is less than 1.
    """

    def __init__(
        self, path, tokenizer, context_length=5, dataset_path="datasets", device="cpu"
    ):
        # Validate before any download/tokenization work is started.
        if context_length < 1:
            raise ValueError(
                f"context_length must be >= 1, got {context_length!r}"
            )
        self.context_length = context_length
        self.dataset_path = dataset_path
        self.device = device
        self.downloader = FileDownloader(self.dataset_path)
        self._load(path, tokenizer)

    def _load(self, path, tokenizer):
        """Fetch (if remote), read and tokenize the corpus at ``path``."""
        # Match the URL scheme exactly: a bare "http" prefix would also send
        # local files such as "http_logs.txt" to the downloader. str() lets
        # pathlib.Path inputs through instead of failing on .startswith.
        if str(path).startswith(("http://", "https://")):
            path = self.downloader.save_file_local(path)
        self.vocab_size = tokenizer.size
        self.raw_data = data_loader(path)
        self.tokens = tokenizer.encode_corpus(self.raw_data)
        # Built once so __getitem__ can slice it instead of re-converting
        # Python lists to tensors on every access.
        self.tokens_tensor = torch.tensor(self.tokens, dtype=torch.long)

    def __len__(self):
        """Return the number of (window, next-token) samples.

        Every start index ``i`` with ``i + context_length < len(tokens)``
        yields a sample, i.e. ``len(tokens) - context_length`` of them
        (never negative).
        """
        return max(0, len(self.tokens) - self.context_length)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``.

        ``x`` has shape ``(context_length,)`` and ``y`` has shape ``(1,)``.
        Both are ``torch.long`` tensors on ``self.device`` and do not share
        storage with the dataset. Negative indices count from the end, as
        for Python sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        end = idx + self.context_length
        # copy=True: on CPU, .to() would otherwise return a view into
        # tokens_tensor, so in-place edits by a caller would corrupt the corpus.
        x = self.tokens_tensor[idx:end].to(self.device, copy=True)
        y = self.tokens_tensor[end : end + 1].to(self.device, copy=True)
        return x, y
if __name__ == "__main__":
    # Smoke test: build the dataset from a remote corpus, batch it, and
    # print the first batch both as raw ids and decoded text.
    from tokkit import PyBytePairTokenizer
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    tokenizer = PyBytePairTokenizer()
    url = "https://raw.githubusercontent.com/NishantBaheti/tokkit/refs/heads/main/datasets/raw/combined.txt"
    dataset = CustomDataset(url, tokenizer, context_length=100)
    dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
    for X_batch, y_batch in tqdm(dataloader):
        print(X_batch, y_batch)
        for x, y in zip(X_batch, y_batch):
            # Pass plain Python ints (the same kind of ids encode_corpus
            # produced) rather than a tensor whose elements are 0-d tensors.
            # NOTE(review): assumes tokkit's decode takes a sequence of ints — confirm.
            print(tokenizer.decode(x.tolist()), tokenizer.decode(y.tolist()))
        # Only the first batch is inspected.
        break