diff --git a/src/config/base.py b/src/config/base.py
index e48c277..bcb0d41 100644
--- a/src/config/base.py
+++ b/src/config/base.py
@@ -27,7 +27,7 @@ def parse_args(base_parser, args, namespace):
     parser.add_argument('--results_base_folder', default="./exps", type=str)
     parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT
     # Dataset params
-    parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2'])
+    parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'openwebtext2'])
     parser.add_argument('--vocab_size', default=50304, type=int)
     parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, mostly useless except for openwebtext2
     # Model params
diff --git a/src/data/arxiv.py b/src/data/arxiv.py
deleted file mode 100644
index bd146f1..0000000
--- a/src/data/arxiv.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import os
-import tarfile
-import logging
-from pathlib import Path
-from typing import Optional
-from multiprocessing import Pool
-from tempfile import NamedTemporaryFile
-from subprocess import Popen, TimeoutExpired, PIPE
-from typing import Tuple, List
-
-import numpy as np
-import requests
-from tqdm.auto import tqdm
-import tiktoken
-
-
-def convert_to_markdown(args: Tuple[Path, Path]):
-    texfile, mdroot = args
-    mdfile = mdroot/f"{texfile.name}.md"
-    with Popen(["pandoc", "--wrap=none", "--from", "latex", texfile,
-                "--output", mdfile], stderr=PIPE) as proc:
-        try:
-            proc.communicate(timeout=1)
-        except TimeoutExpired:
-            proc.kill()
-
-
-
-def fetch_arxiv(root: Path, year: int):
-    # download latex
-    url = f"https://www.cs.cornell.edu/projects/kddcup/download/hep-th-{year}.tar.gz"
-    texroot = root/"tex"
-    print("Downloading Arxiv year", year)
-    req = requests.get(url, timeout=60)
-    with NamedTemporaryFile(suffix=".tar.gz") as f:
-        f.write(req.content)
-        logging.debug("Tar saved in tempfile %s" % f.name)
-        with tarfile.open(f.name) as tar:
-            logging.debug("Extracting tarfile")
-            tar.extractall(texroot)
-
-    # convert to markdown
-    mdroot = root/"md"/str(year)
-    mdroot.mkdir(parents=True)
-    files = list((texroot/str(year)).iterdir())
-    with Pool(os.cpu_count()) as p:
-        args = [(texfile, mdroot) for texfile in files]
-        for _ in tqdm(p.imap_unordered(convert_to_markdown, args),
-                      desc="Converting to markdown", total=len(files)):
-            pass
-
-
-def tokenize_arxiv(root: Path, year: int):
-    tokenizer = tiktoken.get_encoding("gpt2")
-    tokens = []
-    tokens_val = []
-    tokens_test = []
-    mds = root/"md"/str(year)
-
-    # tokenize
-    desc = f"Tokenizing {year}"
-    for i, mdpath in enumerate(tqdm(list(mds.iterdir()), desc=desc)):
-        with open(mdpath, encoding="utf8") as f:
-            text = "".join(f.readlines())
-        if i % 10 <= 6:  # train split
-            tokens += tokenizer.encode(text)
-        elif i % 10 <= 8:  # val split
-            tokens_val += tokenizer.encode(text)
-        else:  # test split
-            tokens_test += tokenizer.encode(text)
-
-    # save to dir
-    tpath = root/str(year)
-    tpath.mkdir(parents=True)
-    for x, name in zip([tokens, tokens_val, tokens_test],
-                       ["train", "val", "test"]):
-        mem = np.memmap(tpath/f"{name}.npy", dtype=np.uint16, mode="w+",
-                        shape=len(x))
-        for i, v in enumerate(x):
-            mem[i] = v
-
-
-def load_arxiv(cachedir: Path, years: Optional[List[int]] = None):
-    all_years = list(range(1992, 2004))
-    if years is None:
-        years = all_years
-    assert set(years) <= set(all_years)
-    root = cachedir/"arxiv"
-    root.mkdir(exist_ok=True, parents=True)
-
-    # download all years requested that are not present
-    for year in years:
-        if not (root/"md"/str(year)).exists():
-            fetch_arxiv(root, year)
-
-    # tokenize all years not previously tokenized
-    for year in years:
-        if not (root/str(year)).exists():
-            tokenize_arxiv(root, year)
-
-    # load meta
-    ret = {}
-    for split in ["train", "val"]:
-        paths = [root/str(year)/f"{split}.npy" for year in years]
-        x = [np.memmap(path, dtype=np.uint16, mode="r") for path in paths]
-        ret[split] = np.concatenate(x)
-    return ret
-
-
-def get_arxiv_2000():
-    return load_arxiv(Path(os.path.dirname(__file__))/"datasets", [2000])
-
-
-def get_arxiv_full():
-    return load_arxiv(Path(os.path.dirname(__file__))/"datasets")
diff --git a/src/data/openwebtext2.py b/src/data/openwebtext2.py
index eef9d50..07cb0ca 100644
--- a/src/data/openwebtext2.py
+++ b/src/data/openwebtext2.py
@@ -51,8 +51,5 @@ def process(example):
                 idx += len(arr_batch)
             arr.flush()
 
-    train_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
-    val_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
-
-    return {'train': train_data, 'val': val_data}
+    return {'train': os.path.join(OWT2_DATA_PATH, 'train.bin'), 'val': os.path.join(OWT2_DATA_PATH, 'val.bin')}
 
diff --git a/src/data/shakespeare.py b/src/data/shakespeare.py
index ab6e022..b73d951 100644
--- a/src/data/shakespeare.py
+++ b/src/data/shakespeare.py
@@ -47,5 +47,4 @@ def get_shakespeare_data():
         mem[:] = x_test
 
     # at this point we know that the binfile was properly created so we load it
-    return {"train": np.memmap(train_path, dtype=np.uint16, mode="r"),
-            "val": np.memmap(test_path, dtype=np.uint16, mode="r")}
+    return {"train": train_path, "val": test_path}
diff --git a/src/data/slimpajama.py b/src/data/slimpajama.py
index c3960d7..198762a 100644
--- a/src/data/slimpajama.py
+++ b/src/data/slimpajama.py
@@ -6,7 +6,6 @@
 
 SPJ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slimpajama6B/")
-SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1")
 
 tknzr = tiktoken.get_encoding("gpt2")
 
 
@@ -60,69 +59,7 @@ def process(example):
                 idx += len(arr_batch)
             arr.flush()
 
-    train_data = np.memmap(
-        os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r"
-    )
-    val_data = np.memmap(
-        os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r"
-    )
-
-    return {"train": train_data, "val": val_data}
-
-
-def get_slimpajama_chunk1(num_proc=40):
-    if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")):
-        os.makedirs(SPJ_DATA_PATH, exist_ok=True)
-        dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1")
-
-        split_dataset = dataset["train"].train_test_split(
-            test_size=0.0005, seed=2357, shuffle=True
-        )
-        split_dataset["val"] = split_dataset.pop("test")
-
-        def process(example):
-            ids = tknzr.encode_ordinary(
-                example["text"]
-            )  # encode_ordinary ignores any special tokens
-            ids.append(
-                tknzr.eot_token
-            )  # add the end of text token, e.g. 50256 for gpt2 bpe
-            out = {"ids": ids, "len": len(ids)}
-            return out
-
-        # tokenize the dataset
-        tokenized = split_dataset.map(
-            process,
-            remove_columns=["text"],
-            desc="tokenizing the splits",
-            num_proc=num_proc,
-        )
-
-        # concatenate all the ids in each dataset into one large file we can use for training
-        for split, dset in tokenized.items():
-            arr_len = np.sum(dset["len"])
-            filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin")
-            dtype = np.uint16  # (can do since enc.max_token_value == 50256 is < 2**16)
-            arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
-            total_batches = min(1024, len(dset))
-
-            idx = 0
-            for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
-                # Batch together samples for faster write
-                batch = dset.shard(
-                    num_shards=total_batches, index=batch_idx, contiguous=True
-                ).with_format("numpy")
-                arr_batch = np.concatenate(batch["ids"])
-                # Write into mmap
-                arr[idx : idx + len(arr_batch)] = arr_batch
-                idx += len(arr_batch)
-            arr.flush()
-
-    train_data = np.memmap(
-        os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r"
-    )
-    val_data = np.memmap(
-        os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r"
-    )
-
-    return {"train": train_data, "val": val_data}
+    return {
+        "train": os.path.join(SPJ_DATA_PATH, "train.bin"),
+        "val": os.path.join(SPJ_DATA_PATH, "val.bin"),
+    }
diff --git a/src/data/utils.py b/src/data/utils.py
index b5c7bb8..4e5a976 100755
--- a/src/data/utils.py
+++ b/src/data/utils.py
@@ -4,56 +4,47 @@
 
 from .shakespeare import get_shakespeare_data
 from .wikitext import get_wikitext_data
-from .arxiv import get_arxiv_2000, get_arxiv_full
 from .openwebtext2 import get_openwebtext2_data
 from .slimpajama import get_slimpajama_data
 
 
 def get_dataset(args) -> Dict[str, np.ndarray]:
-    """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
-    contained in its own python file. The expected format at the moment is a dictionary of np.memmap
-    containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """
-    if args.dataset == 'wikitext':
+    """Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
+    contained in its own python file. The expected format at the moment is a dictionary
+    containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data.
+    This only returns the paths to the tokenized .bin files; the data is memory-mapped lazily and not loaded into memory here.
+ """ + if args.dataset == "wikitext": return get_wikitext_data() if args.dataset == "shakespeare-char": return get_shakespeare_data() - if args.dataset == "arxiv2000": - return get_arxiv_2000() - if args.dataset == "arxiv": - return get_arxiv_full() - if args.dataset == "arxiv+wiki": - arxiv_data = get_arxiv_full() - wiki_data = get_wikitext_data() - train_data = np.concatenate((arxiv_data['train'], wiki_data['train'])) - val_data = np.concatenate((arxiv_data['val'], wiki_data['val'])) - return {'train': train_data, 'val': val_data} - if args.dataset == 'openwebtext2': + if args.dataset == "openwebtext2": return get_openwebtext2_data() if args.dataset == "slimpajama": return get_slimpajama_data() else: raise NotImplementedError(f"Unknow dataset key '{args.dataset}'") + class Dataset(torch.utils.data.Dataset): - def __init__(self, data, sequence_length): + def __init__(self, data_path, sequence_length): super().__init__() - self.data = data + self.data_path = data_path self.sequence_length = sequence_length def __len__(self): - total_length = len(self.data) + data = np.memmap(self.data_path, dtype=np.uint16, mode="r") + total_length = len(data) # chunk the data into sequences of length `sequence_length` - # NOTE: we discard the last remainding sequence if it's not of length `sequence_length` + # NOTE: we discard the last remaining sequence if it's not of length `sequence_length` return (total_length - 1) // self.sequence_length def __getitem__(self, idx): + data = np.memmap(self.data_path, dtype=np.uint16, mode="r") seq_length = self.sequence_length idx = idx * seq_length - x = torch.from_numpy((self.data[idx : idx + seq_length]).astype(np.int64)) - - y = torch.from_numpy( - (self.data[idx + 1 : idx + 1 + seq_length]).astype(np.int64) - ) + x = torch.from_numpy((data[idx : idx + seq_length]).astype(np.int64)) + y = torch.from_numpy((data[idx + 1 : idx + 1 + seq_length]).astype(np.int64)) return x, y diff --git a/src/data/wikitext.py b/src/data/wikitext.py index 646f636..e1ce754 100755 --- a/src/data/wikitext.py +++ b/src/data/wikitext.py @@ -36,7 +36,4 @@ def get_wikitext_data(): eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'val.bin')) print("completed the tokenization process!") - train_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') - val_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') - - return {'train': train_data, 'val': val_data} + return {'train': os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), 'val': os.path.join(WIKITEXT_DATA_PATH, 'val.bin')} diff --git a/src/main.py b/src/main.py index 92ed664..26dea61 100755 --- a/src/main.py +++ b/src/main.py @@ -48,8 +48,8 @@ def main(args): if args.data_in_ram: data = {'train': np.array(data['train']), 'val': np.array(data['val'])} - print(f"Num training tokens: {len(data['train'])}") - print(f"Num validation tokens: {len(data['val'])}") + print(f"Num training tokens: {len(np.memmap(data['train'], dtype=np.uint16, mode='r'))}") + print(f"Num validation tokens: {len(np.memmap(data['val'], dtype=np.uint16, mode='r'))}") model = get_model(args).to(args.device) # todo: take care of initializing the model if args.use_pretrained != 'none' diff --git a/src/optim/base.py b/src/optim/base.py index 241f508..ea24fd7 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -35,7 +35,9 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba ) num_substeps_per_epoch = len(data["train"]) - train_epochs = substep//num_substeps_per_epoch 
+    num_val_batches = len(data["val"])
+
+    train_epochs = substep // num_substeps_per_epoch
 
     if rng_state_dict is not None and rng_state_dict.get("train_sampler_state", None) is not None:
         train_sampler.generator.set_state(rng_state_dict["train_sampler_state"])
 
@@ -45,10 +47,6 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
         sampler_state_before_iter = train_sampler.generator.get_state()
     data_train_iter = iter(data["train"])
 
-
-    # for val data we don't care about epochs? just cycle through (no need to set_epoch to reshuffle)
-    data_val_iter = itertools.cycle(data["val"])
-
     stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []}
 
 
@@ -69,7 +67,7 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
         for _ in range(substep % num_substeps_per_epoch):
             get_batch(data_train_iter, device=extra_args.device)
 
-    
+
    while itr < iterations:
 
        for microstep_idx in range(acc_steps):  # gradient accumulation
@@ -110,15 +108,16 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
            model.eval()
            train_loss = loss.detach().cpu().item() * acc_steps
            current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr
-
-            eval_steps = (
-                24 if itr < iterations else len(data["val"])
-            )
+            # set the epoch and restart the val loader so that evaluation always sees the same batches
+            if hasattr(val_sampler, "set_epoch"):
+                val_sampler.set_epoch(0)
+            data_val_iter = iter(data["val"])
+
            val_acc, val_loss, val_perplexity = eval(
                model,
                data_val_iter,
                extra_args.device,
-                max_num_batches=eval_steps,
+                max_num_batches=32,
                ctx=type_ctx,
            )
@@ -133,15 +132,26 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
                    "iter": itr,
                    "train/loss": train_loss,
                    "val/loss": val_loss,
-                    "val/perplexity": val_perplexity,
                    "val/acc": val_acc,
+                    "val/perplexity": val_perplexity,
                    "lr": current_lr,
                }
 
                if itr == iterations:
-                    logs["val/final-ppl"] = val_perplexity
-                    logs["val/final-acc"] = val_acc
-                    logs["val/final-loss"] = val_loss
+                    # set the epoch and restart the val loader so the final eval always sees the same batches (iterate from the beginning)
+                    if hasattr(val_sampler, "set_epoch"):
+                        val_sampler.set_epoch(0)
+                    data_val_iter = iter(data["val"])
+                    final_val_acc, final_val_loss, final_val_perplexity = eval(
+                        model,
+                        data_val_iter,
+                        extra_args.device,
+                        max_num_batches=num_val_batches,
+                        ctx=type_ctx,
+                    )
+                    logs["val/final-ppl"] = final_val_perplexity
+                    logs["val/final-acc"] = final_val_acc
+                    logs["val/final-loss"] = final_val_loss
 
                wandb.log(logs)
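
Usage note (not part of the patch): a minimal sketch of how the refactored data path fits together, assuming it is run from src/ so that data.utils is importable; the dataset name and sequence length below are placeholder values.

# After this change, get_dataset() returns file paths instead of open np.memmap handles,
# and Dataset re-opens the memmap in __len__/__getitem__, so no split is held in RAM.
from data.utils import get_dataset, Dataset

class Args:  # hypothetical stand-in for the parsed argparse namespace
    dataset = "shakespeare-char"

paths = get_dataset(Args())                  # {'train': '.../train.bin', 'val': '.../val.bin'}
train_set = Dataset(paths["train"], sequence_length=256)
x, y = train_set[0]                          # int64 tensors of shape (256,); y is x shifted by one token
print(len(train_set), x.shape, y.shape)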