import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # capture this script's own source so it can be written into the run log below
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
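
# Muon optimizer and its orthogonalization backends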
def zeropower_via_svd(G, steps=None):
    U, S, V = G.svd()
    return U @ V.T

@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= (X.norm() + eps) # normalize so that the spectral norm is at most 1
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X

zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
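
# Illustrative sanity check (a sketch only, not executed as part of training): after the
# iteration, X should be close to the orthogonal factor of G, with singular values spread
# roughly over the Uniform(0.5, 1.5) range described in the docstring above.
#
#   G = torch.randn(768, 768, device='cuda')
#   X = zeropower_via_newtonschulz5(G, steps=10)
#   S = torch.linalg.svdvals(X.float())   # expected to lie well away from 0, near 1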

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-Schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
      parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]
            for p in group['params']:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)
                if group['nesterov']:
                    g = g.add(buf, alpha=momentum)
                if g.size(0) == 3 * g.size(1): # parameter stacks three square blocks (e.g. fused QKV): orthogonalize each block separately
                    g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))])
                    scale = g.size(1)**0.5
                else:
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    scale = max(g.size(0), g.size(1))**0.5 # scale the orthogonalized update by sqrt of the larger dimension
                p.data.add_(g, alpha=-lr * scale)
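
# Illustrative usage sketch (comments only, for orientation; the actual optimizer setup
# for this run is built further below, where Muon handles the transformer block matrices
# and AdamW the tied embedding / lm_head weight):
#
#   opt = Muon(model.transformer.h.parameters(), lr=3e-4, momentum=0.95)
#   loss.backward()
#   opt.step()

# GPT model definition: rotary embeddings, QK-RMSNorm, squared-ReLU MLP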

class Rotary(torch.nn.Module):

    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached: # recompute the cos/sin cache only when the sequence length changes
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq).to(x.device)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4 # (batch, seq_len, n_head, head_dim)
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).type_as(x)
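
# apply_rotary_emb splits the last dimension into two halves (x1, x2) and rotates each
# frequency pair: (x1, x2) -> (x1*cos + x2*sin, -x1*sin + x2*cos). A minimal shape sketch
# (assumed sizes, for orientation only):
#
#   q = torch.randn(B, T, n_head, head_dim)   # hypothetical activations
#   cos, sin = Rotary(head_dim)(q)            # each of shape (1, T, 1, head_dim//2)
#   q_rot = apply_rotary_emb(q, cos, sin)     # same shape and dtype as q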

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # zero-initialized output projection: the residual branch contributes nothing at init
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_()
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        cos, sin = self.rotary(q)
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK-norm
        y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero-init the projection, as in the attention block

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square() # squared-ReLU activation
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
        x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
        return x

@dataclass
class GPTConfig:
    vocab_size : int = 50304 # 50257 (GPT-2 vocab) rounded up to a multiple of 128
    n_layer : int = 12
    n_head : int = 6
    n_embd : int = 768

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # tie the token embedding and output head weights

    def forward(self, idx, targets=None, return_logits=True):

        # forward the GPT model itself
        x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))

        if targets is not None:
            # if we are given targets, also compute the loss
            logits = self.lm_head(x)
            logits = logits.float() # upcast logits to float32 for the loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] preserves the time dimension
            logits = logits.float()
            loss = None

        # skipping materialization of the logits saves memory when they are not needed
        if not return_logits:
            logits = None

        return logits, loss
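
# Distributed data loader for pre-tokenized .bin shards. Shard layout, as read by the
# helpers below: a header of 256 int32 values (header[0] is the magic number 20240520,
# header[1] the format version, header[2] the token count), followed by the tokens
# themselves stored as uint16.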

def _peek_data_shard(filename):
    # only reads the header and returns the claimed token count
    with open(filename, "rb") as f:
        # the header is 256 int32 values (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: The dataset encoding changed recently; re-run the data preprocessing or refer again to the README.")
        print("---> HINT: For example, re-run `python dev/data/tinyshakespeare.py`, then re-try.")
        exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok

def _load_data_shard(filename):
    with open(filename, "rb") as f:
        # the header is 256 int32 values (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        ntok = header[2] # number of tokens (claimed)
        # the rest of the file is the tokens, stored as uint16
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens

class DistributedDataLoader:
    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T

        # glob the files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # peek at all shards to validate them and count the total number of tokens
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            assert shard_ntok >= num_processes * B * T + 1 # each shard must hold at least one batch per process
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total

        # kick things off
        self.reset()

    def reset(self):
        self.current_shard = 0
        self.current_position = self.process_rank * self.B * self.T # each rank starts at its own offset
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self): # advance to the next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets (inputs shifted by one token)
        # advance the position, and move to the next shard if the current one is exhausted
        self.current_position += B * T * self.num_processes # all ranks stride through the shard together
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()
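
# Training setup: hyperparameters, distributed initialization, model and optimizer
# construction, and the main training loop.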

@dataclass
class Hyperparameters:
    # data
    input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # training .bin shards
    input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # validation .bin shards
    # optimization
    batch_size : int = 8*64 # global batch size, in sequences, across all devices
    device_batch_size : int = 64 # batch size, in sequences, per device
    sequence_length : int = 1024 # sequence length, in tokens
    num_iterations : int = 5100 # number of optimizer steps to run
    learning_rate : float = 0.0036
    warmup_iters : int = 0
    warmdown_iters : int = 1450 # number of iterations of linear lr warmdown at the end
    weight_decay : float = 0
    # evaluation and logging
    val_loss_every : int = 125 # evaluate val loss every this many steps (0 = only at the end)
    val_tokens : int = 10485760 # number of validation tokens to evaluate on
    save_every : int = 0 # save a checkpoint every this many steps (0 = only at the end)
args = Hyperparameters()
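
# With the defaults above, one optimizer step consumes batch_size * sequence_length
# = 512 * 1024 = 524,288 tokens, independent of how many GPUs share the work.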

# set up DDP (distributed data parallel); torchrun sets these environment variables
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = (ddp_rank == 0) # this process does the logging and checkpointing

# convenience variables
B, T = args.device_batch_size, args.sequence_length
# number of micro-steps in each validation loop
assert args.val_tokens % (B * T * ddp_world_size) == 0
val_steps = args.val_tokens // (B * T * ddp_world_size)
# number of gradient-accumulation steps needed to reach the desired global batch size
assert args.batch_size % (B * ddp_world_size) == 0
train_accumulation_steps = args.batch_size // (B * ddp_world_size)
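
# Worked example (assuming the default hyperparameters): on an 8-GPU run,
# train_accumulation_steps = 512 // (64 * 8) = 1, i.e. no gradient accumulation;
# a single-GPU run would instead accumulate 512 // 64 = 8 micro-batches per step.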

train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
if master_process:
    print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
    print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
x, y = train_loader.next_batch()

# initialize the model from scratch
num_vocab = 50304
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
model = model.cuda()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # enable this inductor tuning knob when the torch version exposes it
model = torch.compile(model)
# wrap the model in a DDP container
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module # always contains the raw, unwrapped model
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16) # bfloat16 autocast for train and val forward passes

# initialize the optimizers
optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
                               weight_decay=args.weight_decay, fused=True)
optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95)
optimizers = [optimizer1, optimizer2]
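
# Note the split: AdamW updates the lm_head weight (tied to the token embedding), while
# Muon updates only the 2D matrices inside the transformer blocks, at 0.1x the base
# learning rate -- consistent with the warnings in the Muon docstring about embeddings
# and the final layer.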

# learning rate schedule: linear warmup, constant, then linear warmdown
def get_lr(it):
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while
    elif it < args.num_iterations - args.warmdown_iters:
        return 1.0
    # 3) linear warmdown
    else:
        decay_ratio = (args.num_iterations - it) / args.warmdown_iters
        return decay_ratio
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
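
# get_lr returns a multiplier on each optimizer's base lr, so the schedule is trapezoidal:
# with the defaults (warmup_iters=0, warmdown_iters=1450, num_iterations=5100), the lr is
# held constant for the first 3650 steps and then decays linearly to zero over the final
# 1450 steps.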

# begin logging
if master_process:
    run_id = str(uuid.uuid4())
    logdir = 'logs/%s/' % run_id
    os.makedirs(logdir, exist_ok=True)
    logfile = 'logs/%s.txt' % run_id
    # create the log file
    with open(logfile, "w") as f:
        # begin the log by writing out this file's source code
        f.write('='*100 + '\n')
        f.write(code)
        f.write('='*100 + '\n')
        # record the software/hardware environment, including the full nvidia-smi output
        f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
        import subprocess
        result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        f.write(f'{result.stdout}\n')
        f.write('='*100 + '\n')

training_time_ms = 0 # accumulated training time in ms, excluding validation and checkpointing
# start the clock
torch.cuda.synchronize()
t0 = time.time()

# begin training
train_loader.reset()
for step in range(args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # restart the timer after the first 10 steps, which are not representative (compilation, warmup)
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # number of steps inside the timing window

    # once in a while, evaluate the validation loss
    if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # run the validation batches
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        for _ in range(val_steps):
            x_val, y_val = val_loader.next_batch()
            with ctx:
                _, loss = model(x_val, y_val, return_logits=False)
                val_loss += loss.detach()
                del loss
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # log the val loss to console and to the logfile
        if master_process:
            print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
            with open(logfile, "a") as f:
                f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
        # restart the clock
        torch.cuda.synchronize()
        t0 = time.time()

    # once in a while, save the state of the training process
    if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the model, optimizer states, and this file's code
        log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
        torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
        # restart the clock
        torch.cuda.synchronize()
        t0 = time.time()

    # the last step only runs the validation/checkpointing above; training happens on all other steps
    if last_step:
        break
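
    # Training for this step: each of the train_accumulation_steps micro-batches runs a
    # forward/backward pass; DDP gradient all-reduce is deferred via model.no_sync() until
    # the final micro-batch, after which gradients are averaged and the optimizers step.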
    model.train()
    for i in range(1, train_accumulation_steps+1):
        # forward pass
        with ctx:
            _, loss = model(x, y, return_logits=False)
            train_loss = loss.detach()
        # fetch the next batch while the GPU works
        x, y = train_loader.next_batch()
        # backward pass
        if i < train_accumulation_steps:
            with model.no_sync(): # no need to sync gradients on intermediate accumulation steps
                loss.backward()
        else:
            loss.backward() # sync gradients only on the last micro-batch
    for p in model.parameters():
        p.grad /= train_accumulation_steps
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)

    # per-step logging (master process only)
    if master_process:
        approx_time = training_time_ms + 1000 * (time.time() - t0)
        print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
        with open(logfile, "a") as f:
            f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")

if master_process:
    print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

# clean up the distributed process group
dist.destroy_process_group()