Created December 31, 2024 10:03
import os | |
import sys | |
with open(sys.argv[0]) as f: | |
code = f.read() # read the code of this file ASAP, for logging | |
with open('optimizer.py', 'r', encoding='utf-8') as f: | |
source_code = f.read() | |
code += source_code | |
with open('model.py', 'r', encoding='utf-8') as f: | |
source_code = f.read() | |
code += source_code | |
with open('utils.py', 'r', encoding='utf-8') as f: | |
source_code = f.read() | |
code += source_code | |
with open('dataloading.py', 'r', encoding='utf-8') as f: | |
source_code = f.read() | |
code += source_code | |
import argparse | |
import uuid | |
import time | |
import contextlib | |
import math | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
import torch._inductor.config as config | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from pathlib import Path | |
from optimizer import Muon | |
from model import ModelConfig, ESM, CastedLinear | |
from dataloading import DistributedDataLoader, DistributedPaddedDataLoader | |
def get_args(): | |
parser = argparse.ArgumentParser(description='ESM2 training arguments') | |
# Model hyperparams | |
parser.add_argument('--vocab_size', type=int, default=33, help='vocabulary size') | |
parser.add_argument('--num_hidden_layers', type=int, default=12, help='number of transformer layers') | |
parser.add_argument('--num_attention_heads', type=int, default=6, help='number of attention heads (head dim 128 suggested by @Grad62304977)') | |
parser.add_argument('--hidden_size', type=int, default=768, help='model hidden dimension size') | |
# Data hyperparams | |
parser.add_argument('--input_bin', type=str, default='data/omgprot50/omgprot50_train_*.bin', help='input .bins to train on') | |
parser.add_argument('--input_valid_bin', type=str, default='data/omgprot50/omgprot50_valid_*.bin', help='input .bins to eval validation loss on') | |
parser.add_argument('--input_test_bin', type=str, default='data/omgprot50/omgprot50_test_*.bin', help='input .bins to eval test loss on') | |
# Optimization hyperparams | |
parser.add_argument('--batch_size', type=int, default=4*64*1024, help='batch size, in tokens, across all devices') | |
parser.add_argument('--grad_accum', type=int, default=1, help='gradient accumulation steps (single-GPU only; under DDP this is overridden to ddp_world_size)')
parser.add_argument('--num_steps', type=int, default=20_000, help='number of iterations to run') | |
parser.add_argument('--warmup_steps', type=int, default=1000, help='number of warmup steps') | |
parser.add_argument('--cooldown_steps', type=int, default=1000, help='number of cooldown steps') | |
# Evaluation and logging hyperparams | |
parser.add_argument('--valid_loss_every', type=int, default=100, help='evaluate validation loss every N steps (0 = only at the end)')
parser.add_argument('--hf_model_name', type=str, default='lapp0/esm2_speedrun', help='huggingface model name') | |
parser.add_argument('--token', type=str, default=None, help='huggingface token') | |
parser.add_argument('--save_every', type=int, default=1000, help='save a checkpoint every N steps (None = no saving)')
args = parser.parse_args() | |
return args | |
def get_param_count(model): | |
total_params = 0 | |
for _, param in model.named_parameters(): | |
total_params += param.numel() | |
return total_params | |
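# How this script is typically launched (the filename is an assumption here, e.g. train.py):
#   single GPU:  python train.py
#   multi-GPU:   torchrun --standalone --nproc_per_node=4 train.py
# torchrun sets the RANK / LOCAL_RANK / WORLD_SIZE environment variables that the DDP setup
# below checks for; without them the script falls back to single-GPU mode.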
if __name__ == "__main__": | |
args = get_args() | |
if args.token: | |
from huggingface_hub import login | |
login(args.token) | |
args.token = None | |
model_config = ModelConfig( | |
vocab_size=args.vocab_size, | |
num_hidden_layers=args.num_hidden_layers, | |
num_attention_heads=args.num_attention_heads, | |
hidden_size=args.hidden_size, | |
) | |
# set up DDP (distributed data parallel) if available, otherwise single GPU | |
if 'RANK' in os.environ: | |
ddp_rank = int(os.environ['RANK']) | |
ddp_local_rank = int(os.environ['LOCAL_RANK']) | |
ddp_world_size = int(os.environ['WORLD_SIZE']) | |
device = torch.device(f'cuda:{ddp_local_rank}') | |
torch.cuda.set_device(device) | |
dist.init_process_group(backend='nccl', device_id=device) | |
dist.barrier() | |
master_process = (ddp_rank == 0) | |
else: | |
ddp_rank = 0 | |
ddp_local_rank = 0 | |
ddp_world_size = 1 | |
device = torch.device('cuda:0') | |
torch.cuda.set_device(device) | |
master_process = True | |
print(f'using device: {device}') | |
# begin logging | |
logfile = None | |
if master_process: | |
run_id = uuid.uuid4() | |
Path('logs').mkdir(exist_ok=True) | |
# logdir = Path('logs') / f'{run_id}' | |
# logdir.mkdir() | |
logfile = Path('logs') / f'{run_id}.txt' | |
print(logfile.stem) | |
# create the log file | |
with logfile.open('w') as f: | |
# begin the log by printing this file (the Python code) | |
print(code, file=f) | |
print('=' * 100, file=f) | |
def print0(s, logonly=False): | |
if master_process: | |
with logfile.open('a') as f: | |
if not logonly: | |
print(s) | |
print(s, file=f) | |
# log information about the hardware/software environment this is running on | |
# and print the full `nvidia-smi` to file | |
print0(f'Running python {sys.version}') | |
print0(f'Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:') | |
import subprocess | |
result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) | |
print0(f'{result.stdout}', logonly=True) | |
print0('='*100, logonly=True) | |
print0(f'Model config: {model_config}') | |
print0(f'Args: {args.__dict__}') | |
# calculate the steps of gradient accumulation required to attain the desired global batch size | |
# args.batch_size should refer to the total amount of tokens per backward pass | |
train_accumulation_steps = 1 | |
batch_size = args.batch_size | |
assert ddp_world_size == 1 or args.grad_accum == 1, "Cannot currently use both DDP and gradient accumulation" | |
if ddp_world_size > 1: | |
train_accumulation_steps = ddp_world_size | |
batch_size = args.batch_size // ddp_world_size | |
elif args.grad_accum > 1: | |
train_accumulation_steps *= args.grad_accum | |
batch_size = args.batch_size // args.grad_accum | |
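# Example with the defaults: 262,144 tokens across 4 GPUs -> train_accumulation_steps = 4 and a
# local micro-batch of 65,536 tokens per rank per backward pass (this matches the
# "Train accumulation steps: 4" / "Adjusted local batch size: 65536" lines in the log below).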
print0(f'Train accumulation steps: {train_accumulation_steps}') | |
print0(f'Adjusted local batch size: {batch_size} tokens') | |
print0(f'Across {ddp_world_size} GPUs') | |
print0(f'Total batch size: {args.batch_size} tokens') | |
# load tokens | |
train_loader = DistributedPaddedDataLoader(args.input_bin, batch_size, ddp_rank, ddp_world_size, eos_id=2, pad_id=1) | |
valid_loader = DistributedPaddedDataLoader(args.input_valid_bin, batch_size, ddp_rank, ddp_world_size, eos_id=2, pad_id=1) | |
test_loader = DistributedPaddedDataLoader(args.input_test_bin, batch_size // 8, ddp_rank, ddp_world_size, eos_id=2, pad_id=1) | |
print0(f"Training DataLoader: total number of tokens: {train_loader.total_num_tokens} across {len(train_loader.files)} files") | |
print0(f"Validation DataLoader: total number of tokens: {valid_loader.total_num_tokens} across {len(valid_loader.files)} files") | |
print0(f"Testing DataLoader: total number of tokens: {test_loader.total_num_tokens} across {len(test_loader.files)} files") | |
print0('='*100, logonly=True) | |
valid_steps = valid_loader.total_num_tokens // args.batch_size | |
test_steps = test_loader.total_num_tokens // args.batch_size | |
input_ids = train_loader.next_batch() | |
model = ESM(model_config) | |
model = model.cuda().bfloat16() | |
for m in model.modules(): | |
if isinstance(m, CastedLinear): | |
m.float() | |
config.coordinate_descent_tuning = True # suggested by @Chillee | |
model = torch.compile(model) | |
# wrap model in DDP only if using distributed training | |
if ddp_world_size > 1: | |
model = DDP(model, device_ids=[ddp_local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) | |
raw_model = model.module | |
else: | |
raw_model = model | |
# init the optimizers | |
embed_params = [*raw_model.embed.parameters(), *raw_model.value_embeds.parameters()] | |
params = list(raw_model.blocks.parameters()) | |
matrix_params = [p for p in params if p.ndim == 2] | |
scalar_params = [p for p in params if p.ndim < 2] + [raw_model.skip_weights] | |
optimizer1 = torch.optim.Adam(embed_params, lr=0.6, betas=(0.8, 0.95), fused=True) | |
optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.8, 0.95), fused=True) | |
optimizer3 = Muon(matrix_params, lr=0.05, momentum=0.95) | |
optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.8, 0.95), fused=True) | |
optimizers = [optimizer1, optimizer2, optimizer3, optimizer4] | |
# learning rate decay scheduler (linear warmup and cooldown) | |
def get_lr(it): | |
assert it <= args.num_steps | |
# 1) linear warmup for warmup_steps steps | |
if it < args.warmup_steps: | |
return (it+1) / args.warmup_steps | |
# 2) constant lr for a while | |
elif it < args.num_steps - args.cooldown_steps: | |
return 1.0 | |
# 3) linear cooldown | |
else: | |
decay_ratio = (args.num_steps - it) / args.cooldown_steps | |
return decay_ratio | |
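# With the defaults (warmup_steps=1000, cooldown_steps=1000, num_steps=20000) this is a
# trapezoidal multiplier: it ramps linearly from 0.001 at step 0 to 1.0 at step 999, holds at 1.0
# through step 19000, then decays linearly to 0.0 at step 20000. Each optimizer's base lr above is
# scaled by this factor via LambdaLR below.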
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] | |
mlm_prob = torch.tensor(0.3, dtype=torch.float, device="cuda") | |
mlm_prob_prev = 0.3 | |
frac_mask = torch.tensor(0.5, dtype=torch.float, device="cuda") | |
frac_mask_prev = 0.5 | |
sliding_window_size = torch.tensor(1024 - 128, dtype=torch.int32, device="cuda") | |
sw_prev = 1024 - 128 | |
# Start training loop | |
training_time_ms = 0 | |
# start the clock | |
torch.cuda.synchronize() | |
t0 = time.perf_counter() | |
### BEGIN TRAINING LOOP ### | |
for step in range(args.num_steps + 1): | |
last_step = (step == args.num_steps) | |
# This effectively ignores timing first 10 steps, which are slower for weird reasons. | |
# Alternately, and slightly more correctly in terms of benchmarking, we could do 10 | |
# steps with dummy data first, and then re-initialize the model and reset the loader. | |
if step == 10: | |
training_time_ms = 0 | |
t0 = time.perf_counter() | |
timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val | |
# Anneal the masking schedule (mlm_prob 0.3 -> 0.15, frac_mask 0.5 -> 0.8) and grow the sliding window size in chunks of 128 up to 2048 over training. Sliding window schedule by @fernbear.bsky.social
frac_done = step / args.num_steps # training progress | |
mlm_prob_new = int(((1 - frac_done) * 0.3 + frac_done * 0.15) * 100) / 100 | |
frac_mask_new = int(((1 - frac_done) * 0.5 + frac_done * 0.8) * 100) / 100 | |
sw_size = int(((1 - frac_done) * 1023 + frac_done * 2048) // 128) * 128 | |
if sw_size != sw_prev: | |
sliding_window_size.copy_(sw_size, non_blocking=True) | |
sw_prev = sw_size | |
if frac_mask_prev != frac_mask_new: | |
frac_mask.copy_(frac_mask_new, non_blocking=True) | |
frac_mask_prev = frac_mask_new | |
if mlm_prob_prev != mlm_prob_new: | |
mlm_prob.copy_(mlm_prob_new, non_blocking=True) | |
mlm_prob_prev = mlm_prob_new
# once in a while evaluate the validation dataset | |
if args.valid_loss_every > 0 and step % args.valid_loss_every == 0 or last_step: | |
# stop the clock | |
torch.cuda.synchronize() | |
training_time_ms += 1000 * (time.perf_counter() - t0) | |
# run validation batches | |
model.eval() | |
valid_loader.reset() | |
val_loss = 0.0 | |
with torch.no_grad(): | |
for _ in range(valid_steps): | |
input_ids = valid_loader.next_batch() | |
val_loss += model(input_ids, sliding_window_size) | |
if ddp_world_size > 1: | |
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) | |
val_loss /= valid_steps | |
# log val loss to console and to logfile | |
print0(f'step:{step}/{args.num_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms perplexity:{(math.e**val_loss):.4f} param_count:{get_param_count(model):,}') | |
# start the clock again | |
torch.cuda.synchronize() | |
t0 = time.perf_counter() | |
# save checkpoint every `save_every` steps | |
if master_process and args.save_every: | |
if last_step or (step % args.save_every == 0): | |
# stop the clock | |
torch.cuda.synchronize() | |
training_time_ms += 1000 * (time.perf_counter() - t0) | |
# save the state of the training process | |
log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) | |
torch.save(log, 'logs/state_step%06d.pt' % step) | |
try: | |
if ddp_world_size > 1: | |
model.module.push_to_hub(args.hf_model_name, subfolder='step%06d' % step) | |
else: | |
model.push_to_hub(args.hf_model_name, subfolder='step%06d' % step) | |
except Exception as e: | |
print(e) | |
torch.cuda.synchronize() | |
t0 = time.perf_counter() | |
if last_step: | |
break | |
# --------------- FORWARD AND BACKWARD PASS ----------------- | |
model.train() | |
for i in range(1, train_accumulation_steps + 1): | |
with contextlib.ExitStack() as stack: | |
if ddp_world_size > 1 and i < train_accumulation_steps: # there's no need to sync gradients every accumulation step | |
stack.enter_context(model.no_sync()) | |
#if step >= 5: | |
# stack.enter_context(torch.compiler.set_stance(skip_guard_eval_unsafe=True)) | |
model(input_ids, sliding_window_size, mlm_prob=mlm_prob, frac_mask=frac_mask).backward() | |
input_ids = train_loader.next_batch() | |
if train_accumulation_steps != 1: | |
for p in model.parameters(): | |
p.grad /= train_accumulation_steps | |
# momentum warmup for Muon | |
frac = min(step/300, 1) | |
for group in optimizer3.param_groups: | |
group['momentum'] = (1 - frac) * 0.85 + frac * 0.95 | |
# step the optimizers and schedulers | |
for opt, sched in zip(optimizers, schedulers): | |
opt.step() | |
sched.step() | |
# null the gradients | |
model.zero_grad(set_to_none=True) | |
# --------------- FORWARD AND BACKWARD PASS END ------------------- | |
# everything that follows now is just eval, diagnostics, prints, logging, etc. | |
if step % 100 == 0: | |
approx_time = training_time_ms + 1000 * (time.perf_counter() - t0) | |
print0(f"step:{step+1}/{args.num_steps} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") | |
print0(f"peak memory consumption training: {torch.cuda.max_memory_allocated() // 1024 // 1024 // 1024} GiB") | |
# save the model to huggingface | |
try: | |
if ddp_world_size > 1: | |
model.module.push_to_hub(args.hf_model_name) | |
else: | |
model.push_to_hub(args.hf_model_name) | |
except Exception as e: | |
print(e) | |
torch.cuda.empty_cache() | |
torch.cuda.synchronize() | |
torch.manual_seed(42) | |
model.eval() | |
test_loader.reset() | |
test_loss = 0.0 | |
with torch.no_grad(): | |
for _ in range(test_steps): | |
input_ids = test_loader.next_batch() | |
test_loss += model(input_ids, sliding_window_size) | |
test_loss /= test_steps | |
print0(f"Test results | Loss: {test_loss:.4f} | Perplexity: {math.e**test_loss:.4f}") | |
print0(f"Total train time (min): {training_time_ms / 60000:.2f}") | |
print0(f"Total train time (hours): {training_time_ms / 3600000:.2f}") | |
print0(f"peak memory consumption testing: {torch.cuda.max_memory_allocated() // 1024 // 1024 // 1024} GiB") | |
# ------------------------------------------------------------------------- | |
# clean up nice | |
if ddp_world_size > 1: | |
dist.destroy_process_group() | |
import os | |
import torch | |
import torch.distributed as dist | |
### Muon optimizer | |
@torch.compile | |
def zeropower_via_newtonschulz5(G, steps): | |
""" | |
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a | |
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose | |
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at | |
zero even beyond the point where the iteration no longer converges all the way to one everywhere | |
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T | |
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model | |
performance at all relative to UV^T, where USV^T = G is the SVD. | |
""" | |
assert len(G.shape) == 2 | |
a, b, c = (3.4445, -4.7750, 2.0315) | |
X = G.bfloat16() | |
if G.size(0) > G.size(1): | |
X = X.T | |
# Ensure spectral norm is at most 1 | |
X = X / (X.norm() + 1e-7) | |
# Perform the NS iterations | |
for _ in range(steps): | |
A = X @ X.T | |
B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng | |
X = a * X + B @ X | |
if G.size(0) > G.size(1): | |
X = X.T | |
return X | |
class Muon(torch.optim.Optimizer): | |
""" | |
Muon - MomentUm Orthogonalized by Newton-schulz | |
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- | |
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal | |
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has | |
the advantage that it can be stably run in bfloat16 on the GPU. | |
Some warnings: | |
- This optimizer assumes that all parameters passed in are 2D. | |
- It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D | |
parameters; those should all be optimized by a standard method (e.g., AdamW). | |
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. | |
- We believe it is unlikely to work well for training with small batch size. | |
- We believe it may not work well for finetuning pretrained models, but we haven't tested this. | |
- We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). | |
Arguments: | |
lr: The learning rate used by the internal SGD. | |
momentum: The momentum used by the internal SGD. | |
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) | |
ns_steps: The number of Newton-Schulz iteration steps to use. | |
""" | |
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): | |
self.world_size = int(os.environ.get('WORLD_SIZE', '1')) | |
self.rank = int(os.environ.get('RANK', '0')) | |
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) | |
params = list(params) | |
assert all(isinstance(p, torch.Tensor) for p in params) | |
sizes = {p.numel() for p in params} | |
param_groups = [ | |
{ | |
'params': [p for p in params if p.numel() == size], | |
'update_buffer': [ | |
torch.empty(size, device='cuda', dtype=torch.bfloat16) | |
for _ in range(self.world_size) | |
], | |
} | |
for size in sizes | |
] | |
super().__init__(param_groups, defaults) | |
def step(self): | |
for group in self.param_groups: | |
lr = group['lr'] | |
momentum = group['momentum'] | |
nesterov = group['nesterov'] | |
ns_steps = group['ns_steps'] | |
update_buffers = group['update_buffer'] | |
# generate weight updates in distributed fashion | |
params = group['params'] | |
assert len(params) % self.world_size == 0 | |
handle = None | |
params_world = None | |
def update_prev(): | |
if params_world is None: | |
return | |
if handle is not None: | |
handle.wait() | |
for p_world, g_world in zip(params_world, update_buffers): | |
p_world.data.add_( | |
g_world.view_as(p_world), | |
alpha=-lr * max(1, p_world.size(0) / p_world.size(1)) ** 0.5, | |
) | |
for base_i in range(len(params))[::self.world_size]: | |
p = params[base_i + self.rank] | |
g = p.grad | |
assert g is not None | |
state = self.state[p] | |
if 'momentum_buffer' not in state: | |
state['momentum_buffer'] = torch.zeros_like(g) | |
buf = state['momentum_buffer'] | |
buf.lerp_(g, 1 - momentum) | |
g = g.lerp_(buf, momentum) if nesterov else buf | |
g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten() | |
update_prev() | |
if self.world_size > 1: | |
handle = dist.all_gather(update_buffers, g, async_op=True) | |
else: | |
update_buffers[0].copy_(g) | |
handle = None | |
params_world = params[base_i : base_i + self.world_size] | |
update_prev() | |
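# A minimal usage sketch for Muon (not part of the training run; assumes a CUDA device and a
# working torch.compile/Triton install, since the update buffers and the Newton-Schulz step are
# CUDA/bfloat16 specific). It fits a single 2D weight to a random target to show the optimizer
# API end to end.
if __name__ == "__main__":
    torch.manual_seed(0)
    W = torch.nn.Parameter(torch.randn(256, 128, device="cuda"))
    target = torch.randn(256, 128, device="cuda")
    opt = Muon([W], lr=0.02, momentum=0.95)
    for i in range(10):
        loss = (W - target).pow(2).mean()
        loss.backward()
        opt.step()          # orthogonalized momentum update via zeropower_via_newtonschulz5
        opt.zero_grad(set_to_none=True)
        print(f"iter {i}: loss {loss.item():.4f}")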
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.attention.flex_attention import flex_attention, create_block_mask | |
from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel | |
from typing import Optional, Tuple, List, Any | |
try: | |
from .utils import ProteinMasker | |
except ImportError: | |
from utils import ProteinMasker | |
class ModelConfig(PretrainedConfig): | |
""" | |
33 tokens: https://huggingface.co/Synthyra/ESMplusplus_large/blob/main/modeling_esm_plusplus.py#L868-L874 | |
ESM2-8M has 6 layers, 20 heads, 320 hidden dim: https://huggingface.co/facebook/esm2_t6_8M_UR50D/blob/main/config.json | |
ESM2-35M has 12 layers, 20 heads, 480 hidden dim: https://huggingface.co/facebook/esm2_t12_35M_UR50D/blob/main/config.json | |
ESM2-150M has 30 layers, 20 heads, 640 hidden dim: https://huggingface.co/facebook/esm2_t30_150M_UR50D/blob/main/config.json | |
ESM2-650M has 33 layers, 20 heads, 1280 hidden dim: https://huggingface.co/facebook/esm2_t33_650M_UR50D/blob/main/config.json | |
""" | |
def __init__( | |
self, | |
vocab_size=33, | |
hidden_size=768, | |
num_hidden_layers=12, | |
num_attention_heads=12, | |
expansion_ratio=8/3, | |
**kwargs, | |
): | |
super().__init__(**kwargs) | |
self.vocab_size = vocab_size | |
self.hidden_size = hidden_size | |
self.num_hidden_layers = num_hidden_layers | |
self.num_attention_heads = num_attention_heads | |
self.expansion_ratio = expansion_ratio | |
def norm(x: torch.Tensor) -> torch.Tensor: | |
return F.rms_norm(x, (x.size(-1),)) | |
class CastedLinear(nn.Linear): | |
def __init__(self, in_features, out_features): | |
super().__init__(in_features, out_features, bias=False) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
return F.linear(x, self.weight.to(x.dtype)) | |
class Rotary(nn.Module): | |
def __init__(self, dim, base=10000): | |
super().__init__() | |
self.register_buffer('inv_freq', (1 / base) ** (torch.arange(0, dim, 2) / dim)) | |
self.seq_len_cached = None | |
self.cos_cached = None | |
self.sin_cached = None | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
seq_len = x.shape[1] | |
if seq_len != self.seq_len_cached: | |
t = torch.arange(seq_len, device=x.device) | |
freqs = torch.outer(t, self.inv_freq) | |
self.seq_len_cached = seq_len | |
self.cos_cached = freqs.cos() | |
self.sin_cached = freqs.sin() | |
cos, sin = self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] | |
# apply_rotary_emb(x, cos, sin) | |
x1, x2 = x.chunk(2, dim=3) | |
y1 = x1 * cos + x2 * sin | |
y2 = x1 * (-sin) + x2 * cos | |
return torch.cat((y1, y2), 3).type_as(x) | |
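# Rotary notes: inv_freq follows the standard RoPE geometric spacing base^(-2i/dim); the rotation
# acts on the two contiguous halves of each head dimension (x1, x2) rather than on interleaved
# pairs, i.e. y1 = x1*cos + x2*sin and y2 = -x1*sin + x2*cos, and the cos/sin tables are cached
# per sequence length.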
class SelfAttention(nn.Module): | |
""" | |
TODO | |
Add F.scaled_dot_product_attention (SDPA) option
Add causal option (flex and sdpa) | |
""" | |
def __init__(self, dim, num_attention_heads): | |
super().__init__() | |
assert dim % num_attention_heads == 0 | |
self.num_attention_heads = num_attention_heads | |
self.qkv = CastedLinear(dim, 3 * dim) | |
self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5])) | |
self.rotary = Rotary(dim // num_attention_heads) # dim // num_attention_heads = head_dim | |
self.o_proj = CastedLinear(dim, dim) | |
self.o_proj.weight.data.zero_() # zero init suggested by @Grad62304977 | |
def forward_sdpa(self, x: torch.Tensor, vi: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
""" | |
TODO | |
Question: is this output actually different from the flex attention output?
Likely yes, because of the score_mod and/or soft capping.
Would be good to be able to do inference this way for typical PLM inference pipelines | |
https://pytorch.org/blog/flexattention/ | |
""" | |
B, T = x.size(0), x.size(1) # batch size, sequence length | |
qkv = self.qkv(x) | |
q, k, v = qkv.chunk(3, dim=-1) | |
q = q.view(B, T, self.num_attention_heads, -1) | |
k = k.view(B, T, self.num_attention_heads, -1) | |
v = v.view(B, T, self.num_attention_heads, -1) | |
v = self.lambdas[0] * v + self.lambdas[1] * vi.view_as(v) # @KoszarskyB & @Grad62304977 | |
q, k = norm(q), norm(k) # QK norm @Grad62304977 | |
q, k = self.rotary(q), self.rotary(k) | |
y = F.scaled_dot_product_attention( | |
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), | |
attn_mask=attention_mask, | |
dropout_p=0.0, | |
is_causal=False, | |
enable_gqa=True | |
) | |
y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side | |
y = self.o_proj(y) | |
return y | |
def forward(self, x: torch.Tensor, vi: torch.Tensor, block_mask: torch.Tensor) -> torch.Tensor: | |
B, T = x.size(0), x.size(1) # batch size, sequence length | |
assert B == 1, "Must use batch size = 1 for FlexAttention" | |
qkv = self.qkv(x) | |
q, k, v = qkv.chunk(3, dim=-1) | |
q = q.view(B, T, self.num_attention_heads, -1) | |
k = k.view(B, T, self.num_attention_heads, -1) | |
v = v.view(B, T, self.num_attention_heads, -1) | |
v = self.lambdas[0] * v + self.lambdas[1] * vi.view_as(v) # @KoszarskyB & @Grad62304977 | |
q, k = norm(q), norm(k) # QK norm @Grad62304977 | |
q, k = self.rotary(q), self.rotary(k) | |
y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, enable_gqa=True, kernel_options = { | |
"BLOCK_M": 64, "BLOCK_N": 64, # forward | |
"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32 # backwards | |
}) | |
y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side | |
y = self.o_proj(y) | |
return y | |
def correction_fn(expansion_ratio: float, d_model: int) -> int: | |
return int(((expansion_ratio * d_model) + 255) // 256 * 256) | |
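# correction_fn rounds expansion_ratio * d_model up to the next multiple of 256,
# e.g. correction_fn(8/3, 768) == 2048. Note that MLP below does not call it and instead
# hardcodes intermediate_dim = 828.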
class MLP(nn.Module): | |
def __init__(self, dim, expansion_ratio): | |
super().__init__() | |
intermediate_dim = 828 | |
self.up = CastedLinear(dim, intermediate_dim) | |
self.down = CastedLinear(intermediate_dim, dim) | |
self.down.weight.data.zero_() # zero init suggested by @Grad62304977 | |
self.relu = nn.ReLU() | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
# https://arxiv.org/abs/2109.08668v2 | |
# ReLU squared ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 | |
return self.down(self.relu(self.up(x)).square()) | |
class Block(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.attn = SelfAttention(config.hidden_size, config.num_attention_heads) | |
self.mlp = MLP(config.hidden_size, config.expansion_ratio) | |
self.lambdas = nn.Parameter(torch.tensor([1., 0.])) | |
def sdpa_forward(self, x: torch.Tensor, vi: torch.Tensor, x0: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
x = self.lambdas[0] * x + self.lambdas[1] * x0 | |
x = x + self.attn.forward_sdpa(norm(x), vi, attention_mask) | |
x = x + self.mlp(norm(x)) | |
return x | |
def forward(self, x: torch.Tensor, vi: torch.Tensor, x0: torch.Tensor, block_mask: torch.Tensor) -> torch.Tensor: | |
x = self.lambdas[0] * x + self.lambdas[1] * x0 | |
x = x + self.attn(norm(x), vi, block_mask) | |
x = x + self.mlp(norm(x)) | |
return x | |
class ValueEmbedding(nn.Module): | |
def __init__(self, config: "ModelConfig", padding_idx): | |
super().__init__() | |
self.embed = nn.ModuleList([ | |
nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=padding_idx) | |
for _ in range(config.num_hidden_layers // 2) | |
]) | |
def forward(self, inputs: torch.Tensor) -> List[torch.Tensor]: | |
ve = [emb(inputs) for emb in self.embed] | |
ve += reversed(ve) | |
return ve | |
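# With num_hidden_layers = 12 there are 6 embedding tables and the returned list is
# [e0, ..., e5, e5, ..., e0], so encoder layer i and decoder layer (num_decoder_layers - 1 - i)
# share the same token value embedding (the U-net pairing used in ESM.encoder_pass below).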
class ESM(PreTrainedModel): | |
""" | |
TODO | |
Add causal option (flex and sdpa) | |
""" | |
config_class = ModelConfig | |
def __init__(self, config: ModelConfig): | |
super().__init__(config) | |
self.config = config | |
tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D') | |
self.masker = ProteinMasker(tokenizer, 0.20) # 20% masking rate https://arxiv.org/abs/2301.06568 | |
self.inference_masker = ProteinMasker(tokenizer, 0.15) # 15% masking rate for inference, ESM2 | |
self.cls_id = tokenizer.cls_token_id | |
self.vocab_size = tokenizer.vocab_size | |
self.num_hidden_layers = config.num_hidden_layers | |
# U-net design by @brendanh0gan | |
assert config.num_hidden_layers % 2 == 0, "Number of layers should be even for U-net design" | |
self.num_encoder_layers = config.num_hidden_layers // 2 # Half of the layers for encoder | |
self.num_decoder_layers = config.num_hidden_layers - self.num_encoder_layers # Remaining for decoder | |
# Add learnable skip connection weights for decoder layers | |
self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers)) | |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size, padding_idx=tokenizer.pad_token_id) | |
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)]) | |
# token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning | |
# U-net structure on token value embeddings by @leloykun | |
self.value_embeds = ValueEmbedding(config, padding_idx=tokenizer.pad_token_id) | |
self.lm_head = CastedLinear(config.hidden_size, self.vocab_size) | |
self.lm_head.weight.data.zero_() # @Grad62304977 | |
self.cross_entropy = nn.CrossEntropyLoss() | |
def get_logits(self, x: torch.Tensor) -> torch.Tensor: | |
x = norm(x) | |
logits = self.lm_head(x) | |
logits = 30 * torch.tanh(logits / 30) # @Grad62304977 | |
logits = logits.float() | |
return logits | |
def encoder_pass(self, input_ids: torch.Tensor, sliding_window_size: torch.Tensor) -> torch.Tensor: | |
input_ids = input_ids.flatten() # flex_attention needs batch 1 | |
docs = (input_ids == self.cls_id).cumsum(0) | |
def doc_mask_mod(b, h, q_idx, kv_idx): | |
bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size | |
doc_mask = docs[q_idx] == docs[kv_idx] | |
return bidirectional_sliding_window_mask & doc_mask | |
S = len(input_ids) | |
block_mask = create_block_mask(doc_mask_mod, None, None, S, S) | |
x = self.embed(input_ids[None]) | |
x = norm(x) # @Grad62304977 | |
x0 = x | |
ve = self.value_embeds(input_ids) | |
ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:] | |
# Store outputs for U-Net skip connections | |
skip_connections = [] | |
# Encoder pass - process only the first half of the blocks | |
for i in range(self.num_encoder_layers): | |
x = self.blocks[i](x, ve_enc[i], x0, block_mask) | |
skip_connections.append(x) | |
# Decoder pass - process the remaining blocks with weighted skip connections | |
for i in range(self.num_decoder_layers): | |
x = x + self.skip_weights[i] * skip_connections.pop() | |
# U-net structure on token value embeddings by @leloykun | |
x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask) | |
return x | |
def get_vector_embeddings(self, input_ids: torch.Tensor, sliding_window_size: torch.Tensor) -> torch.Tensor: | |
docs = (input_ids == self.cls_id).cumsum(dim=0) # shape: [S] | |
x = self.encoder_pass(input_ids, sliding_window_size) | |
x = x.view(-1, self.config.hidden_size) | |
# At this point, x is shape [S, hidden_size] | |
# We want to mean-pool across each document index. | |
# Convert docs to 0-based so we can do nice indexing | |
num_docs = docs.max().item() | |
doc_ids = docs - 1 # Now documents are labeled [0, 1, 2, ...] | |
# Mean-pool across tokens belonging to each doc | |
doc_embeds = [] | |
for doc_idx in range(num_docs): | |
mask = (doc_ids == doc_idx) | |
# Collect all token embeddings for this doc and average | |
doc_embeds.append(x[mask].mean(dim=0)) | |
# Stack into [num_documents, hidden_size] | |
return torch.stack(doc_embeds, dim=0) | |
def inference(self, input_ids: torch.Tensor, sliding_window_size: torch.Tensor = None) -> Tuple[torch.Tensor, Any, Any]: | |
input_ids, labels = self.inference_masker(input_ids) | |
last_hidden = self.encoder_pass(input_ids, sliding_window_size) | |
logits = self.get_logits(last_hidden) | |
loss = None | |
if labels is not None: | |
loss = self.cross_entropy(logits.view(-1, self.vocab_size), labels.view(-1).long()) | |
return logits, loss, labels | |
def forward(self, input_ids: torch.Tensor, sliding_window_size: torch.Tensor, mlm_prob=None, frac_mask=None) -> torch.Tensor: | |
input_ids, labels = self.masker(input_ids, mlm_prob, frac_mask) | |
last_hidden = self.encoder_pass(input_ids, sliding_window_size) | |
logits = self.get_logits(last_hidden) | |
return self.cross_entropy(logits.view(-1, self.vocab_size), labels.view(-1).long()) | |
if __name__ == '__main__': | |
""" | |
TODO | |
look at MSE between flex attention outputs and sdpa outputs | |
""" | |
import torch | |
from typing import Tuple, Optional | |
""" | |
Standardized MLM masking approach for consistency | |
""" | |
class ProteinMasker: | |
def __init__(self, tokenizer, mlm_probability=0.15): | |
""" | |
Initialize the ProteinMasker with the given tokenizer and masking parameters. | |
Of the masked tokens, 80% are replaced with [MASK], 10% are replaced with a random amino acid token, and 10% are unchanged. | |
""" | |
self.tokenizer = tokenizer | |
self.mlm_probability = torch.tensor(mlm_probability) | |
self.mask_token_id = tokenizer.mask_token_id | |
self.special_tokens = torch.tensor(tokenizer.all_special_ids) | |
canonical_amino_acids = 'ACDEFGHIKLMNPQRSTVWY' | |
canonical_amino_acids_ids = tokenizer.convert_tokens_to_ids(list(canonical_amino_acids)) | |
self.low_range = min(canonical_amino_acids_ids) | |
self.high_range = max(canonical_amino_acids_ids) | |
def __call__(self, input_ids: torch.Tensor, mlm_probability=None, frac_masked=None) -> Tuple[torch.Tensor, torch.Tensor]: | |
mlm_probability = mlm_probability or self.mlm_probability | |
frac_masked = frac_masked or torch.tensor(0.8, device=input_ids.device) | |
labels = input_ids.clone() | |
# Create special tokens mask using broadcasting | |
special_tokens = self.special_tokens.to(input_ids.device) | |
special_tokens_mask = (input_ids[..., None] == special_tokens).any(-1) | |
# Create probability matrix and mask special tokens | |
probability_matrix = torch.ones_like(labels, dtype=torch.float) * mlm_probability | |
probability_matrix.masked_fill_(special_tokens_mask, value=0.0) | |
# Create masked indices | |
masked_indices = torch.bernoulli(probability_matrix).bool() | |
labels[~masked_indices] = -100 # We only compute loss on masked tokens | |
# A fraction frac_masked of the masked positions (80% by default) are replaced with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli( | |
torch.ones_like(probability_matrix, dtype=torch.float) * frac_masked | |
).bool() & masked_indices | |
input_ids[indices_replaced] = self.mask_token_id | |
# Half of the remaining masked positions (10% with the default frac_masked=0.8) are replaced with a random canonical amino acid
indices_random = torch.bernoulli(torch.full_like(probability_matrix, 0.5)).bool() & masked_indices & ~indices_replaced | |
random_words = torch.randint(low=self.low_range, high=self.high_range + 1, size=labels.shape, dtype=input_ids.dtype, device=labels.device)  # +1 because randint's upper bound is exclusive
input_ids[indices_random] = random_words[indices_random] | |
# The remaining masked positions (10% with the default) are left unchanged
return input_ids, labels | |
if __name__ == "__main__": | |
from transformers import EsmTokenizer | |
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") | |
test_seqs = [ | |
'MNFKYKLYSYITIFQIILILPTIVASNERCIALGGVCKDFSDCTGNYKPIDKHCDGSNNIKCCIRKIECPTSQNSNFTISGKNKEDEALPFIFKSEGGCQNDKNDNGNKINGKIGYTCAGITPMVGWKNKENYFSYAIKECTNDTNFTYCAYKLNENKFREGAKNIYIDKYAVAGKCNNLPQPAYYVCFDTSVNHGSGWSSKTITANPIGNMDGREYGLLLNKKSREKYINIVKNDSSQEKYLNGWLSRADDREKYCNNYCTSNCNCDNSASKASVSSNTNTTDIYNSVNTVDSDICNCDDNEPTDFLDDDYINNEEEIDEEIIDQEEY', | |
'MYRTALYFTVCSIWLCQIITGVLSLKCKCDLCKDKNYTCITDGYCYTSATLKDGVILYNYRCLDLNFPMRNPMFCHKQIPIHHEFTLECCNDRDFCNIRLVPKLTPKDNATSDTSLGTIEIAVVIILPTLVICIIAMAIYLYYQNKRSTHHHLGLGDDSIEAPDHPILNGVSLKHMIEMTTSGSGSGLPLLVQRSIARQIQLVEIIGQGRYGEVWRGRWRGENVAVKIFSSREERSWFREAEIYQTVMLRHDNILGFIAADNKGVLSLKCKCDLCKDKNYTCITDGYCYTSATLKDGVILYNYRQLGASLNRFXVYALGLIFWEISRRCNVGGIYDEYQLPFYDAVPSDPTIEEMRRVVCVERQRPSIPNRWQSCEALHVMSKLMKECWYHNATARLTALRIKKTLANFRASEELKM' | |
] | |
test_ids = tokenizer(test_seqs, return_tensors="pt", padding=True).input_ids | |
masker = ProteinMasker(tokenizer, mlm_probability=0.5) | |
print(masker.mask_token_id) | |
print(masker.special_tokens) | |
print(masker.low_range, masker.high_range) | |
# First set of masking | |
masked_ids1, labels1 = masker(test_ids.clone()) | |
masked_ids2, labels2 = masker(test_ids.clone()) | |
print("Before setting seed:") | |
print("Original: ", test_ids[0][:20].tolist()) | |
print("Masking 1:", masked_ids1[0][:20].tolist()) | |
print("Masking 2:", masked_ids2[0][:20].tolist()) | |
print("Are they equal?", torch.equal(masked_ids1, masked_ids2)) | |
# Now with seed | |
torch.manual_seed(42) | |
masked_ids3, labels3 = masker(test_ids.clone()) | |
torch.manual_seed(42) | |
masked_ids4, labels4 = masker(test_ids.clone()) | |
print("\nAfter setting seed:") | |
print("Original: ", test_ids[0][:20].tolist()) | |
print("Masking 3:", masked_ids3[0][:20].tolist()) | |
print("Masking 4:", masked_ids4[0][:20].tolist()) | |
print("Are they equal?", torch.equal(masked_ids3, masked_ids4)) | |
import torch | |
from pathlib import Path | |
def _peek_data_shard(file: Path): | |
# only reads the header, returns header data | |
# header is 256 int32 | |
header = torch.from_file(f"{file}", False, 256, dtype=torch.int32) | |
assert header[0] == 20240520, "magic number mismatch in the data .bin file" | |
assert header[1] == 1, "unsupported version" | |
return int(header[2]) # number of tokens (claimed) | |
def _load_data_shard(path: Path, num_tokens): | |
with path.open("rb", buffering=0) as f: | |
tokens = torch.empty(num_tokens, dtype=torch.uint8, pin_memory=True) | |
f.seek(256 * 4) | |
nbytes = f.readinto(tokens.numpy()) | |
assert nbytes == num_tokens, "number of tokens read does not match header?" | |
return tokens | |
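# Shard layout implied by the two helpers above: a 256-element int32 header
# (header[0] = 20240520 magic, header[1] = version 1, header[2] = token count), followed
# immediately by the tokens as raw uint8 starting at byte offset 256 * 4 = 1024.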
class DistributedDataLoader: | |
def __init__(self, filename_pattern, batch_size, process_rank, num_processes): | |
self.process_rank = process_rank | |
self.num_processes = num_processes | |
self.batch_size = batch_size | |
# glob files that match the pattern | |
self.files = sorted(Path.cwd().glob(filename_pattern)) | |
assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" | |
# load and validate all data shards, count number of tokens in total | |
self.files_num_tokens = [_peek_data_shard(file) for file in self.files] | |
self.total_num_tokens = sum(self.files_num_tokens) | |
self.reset() | |
def reset(self): | |
self.current_shard = -1 | |
self.advance() | |
def advance(self): # advance to next data shard | |
self.current_shard = (self.current_shard + 1) % len(self.files) | |
self.current_position = self.process_rank * self.batch_size | |
self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard]) | |
def next_batch(self): | |
batch_size = self.batch_size * self.num_processes | |
buf = self.tokens[self.current_position:self.current_position+self.batch_size] | |
# host side async is sufficient; | |
# no performance improvement was observed when introducing a separate stream. | |
input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True) # inputs | |
# advance current position and load next shard if necessary | |
self.current_position += batch_size | |
if self.current_position + batch_size >= len(self.tokens): | |
self.advance() | |
return input_ids | |
class DistributedPaddedDataLoader(DistributedDataLoader): | |
def __init__(self, filename_pattern, seq_len, process_rank, num_processes, eos_id, pad_id): | |
super().__init__(filename_pattern, seq_len, process_rank, num_processes) | |
assert (eos_id is None) == (pad_id is None) | |
self.eos_id = eos_id | |
self.pad_id = pad_id | |
self.padding = (eos_id is not None) and (pad_id is not None) | |
def reset(self): | |
self.current_shard = self.process_rank - self.num_processes | |
self.advance() | |
def advance(self): # advance to next data shard | |
self.current_shard = (self.current_shard + self.num_processes) % len(self.files) | |
self.current_position = 0 | |
self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard]) | |
def next_batch(self): | |
end_pos = self.current_position + self.batch_size | |
buf = self.tokens[self.current_position : end_pos] | |
input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True) | |
keep = (input_ids == self.eos_id).cumsum(dim=0).argmax().item() | |
keep = max(keep or 0, self.batch_size - 2048) | |
input_ids[keep + 1:] = self.pad_id | |
# advance current position and load next shard if necessary | |
self.current_position += keep | |
if self.current_position + self.batch_size >= len(self.tokens): | |
self.advance() | |
return input_ids | |
==================================================================================================== | |
Running python 3.11.10 | packaged by conda-forge | (main, Oct 16 2024, 01:27:36) [GCC 13.3.0] | |
Running pytorch 2.6.0.dev20241203+cu124 compiled for CUDA 12.4 | |
nvidia-smi: | |
Tue Dec 31 02:54:59 2024 | |
+-----------------------------------------------------------------------------------------+ | |
| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 | | |
|-----------------------------------------+------------------------+----------------------+ | |
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | |
| | | MIG M. | | |
|=========================================+========================+======================| | |
| 0 NVIDIA GeForce RTX 4090 On | 00000000:42:00.0 Off | Off | | |
| 33% 31C P2 41W / 450W | 1756MiB / 24564MiB | 89% Default | | |
| | | N/A | | |
+-----------------------------------------+------------------------+----------------------+ | |
| 1 NVIDIA GeForce RTX 4090 On | 00000000:81:00.0 Off | Off | | |
| 40% 34C P2 40W / 450W | 591MiB / 24564MiB | 0% Default | | |
| | | N/A | | |
+-----------------------------------------+------------------------+----------------------+ | |
| 2 NVIDIA GeForce RTX 4090 On | 00000000:82:00.0 Off | Off | | |
| 31% 30C P2 48W / 450W | 591MiB / 24564MiB | 0% Default | | |
| | | N/A | | |
+-----------------------------------------+------------------------+----------------------+ | |
| 3 NVIDIA GeForce RTX 4090 On | 00000000:C1:00.0 Off | Off | | |
| 40% 34C P2 33W / 450W | 591MiB / 24564MiB | 0% Default | | |
| | | N/A | | |
+-----------------------------------------+------------------------+----------------------+ | |
+-----------------------------------------------------------------------------------------+ | |
| Processes: | | |
| GPU GI CI PID Type Process name GPU Memory | | |
| ID ID Usage | | |
|=========================================================================================| | |
+-----------------------------------------------------------------------------------------+ | |
==================================================================================================== | |
Model config: ModelConfig { | |
"expansion_ratio": 2.6666666666666665, | |
"hidden_size": 768, | |
"num_attention_heads": 6, | |
"num_hidden_layers": 12, | |
"transformers_version": "4.47.1", | |
"vocab_size": 33 | |
} | |
Args: {'vocab_size': 33, 'num_hidden_layers': 12, 'num_attention_heads': 6, 'hidden_size': 768, 'input_bin': 'data/omgprot50/omgprot50_train_*.bin', 'input_valid_bin': 'data/omgprot50/omgprot50_valid_*.bin', 'input_test_bin': 'data/omgprot50/omgprot50_test_*.bin', 'batch_size': 262144, 'grad_accum': 1, 'num_steps': 20000, 'warmup_steps': 1000, 'cooldown_steps': 1000, 'valid_loss_every': 100, 'hf_model_name': 'lapp0/esm2_speedrun', 'token': None, 'save_every': 1000} | |
Train accumulation steps: 4 | |
Adjusted local batch size: 65536 tokens | |
Across 4 GPUs | |
Total batch size: 262144 tokens | |
Training DataLoader: total number of tokens: 4200000000 across 42 files | |
Validation DataLoader: total number of tokens: 2097660 across 1 files | |
Testing DataLoader: total number of tokens: 3686279 across 1 files | |
==================================================================================================== | |
step:0/20000 val_loss:3.4965 train_time:0ms step_avg:nanms perplexity:33.0000 param_count:43,776,054 | |
step:1/20000 train_time:29668ms step_avg:nanms | |
step:100/20000 val_loss:2.6435 train_time:112696ms step_avg:1252.18ms perplexity:14.0618 param_count:43,776,054 | |
step:101/20000 train_time:113932ms step_avg:1252.00ms | |
step:200/20000 val_loss:2.6045 train_time:238036ms step_avg:1252.82ms perplexity:13.5241 param_count:43,776,054 | |
step:201/20000 train_time:239274ms step_avg:1252.74ms | |
step:300/20000 val_loss:2.5824 train_time:363475ms step_avg:1253.36ms perplexity:13.2288 param_count:43,776,054 | |
step:301/20000 train_time:364722ms step_avg:1253.34ms | |
step:400/20000 val_loss:2.5603 train_time:489191ms step_avg:1254.34ms perplexity:12.9398 param_count:43,776,054 | |
step:401/20000 train_time:490439ms step_avg:1254.32ms | |
step:500/20000 val_loss:2.5402 train_time:614711ms step_avg:1254.51ms perplexity:12.6825 param_count:43,776,054 | |
step:501/20000 train_time:615965ms step_avg:1254.51ms | |
step:600/20000 val_loss:2.5306 train_time:740145ms step_avg:1254.48ms perplexity:12.5605 param_count:43,776,054 | |
step:601/20000 train_time:741395ms step_avg:1254.48ms | |
step:700/20000 val_loss:2.5129 train_time:865743ms step_avg:1254.70ms perplexity:12.3402 param_count:43,776,054 | |
step:701/20000 train_time:866991ms step_avg:1254.69ms | |
step:800/20000 val_loss:2.5063 train_time:991379ms step_avg:1254.91ms perplexity:12.2595 param_count:43,776,054 | |
step:801/20000 train_time:992615ms step_avg:1254.89ms | |
step:900/20000 val_loss:2.4996 train_time:1116893ms step_avg:1254.94ms perplexity:12.1779 param_count:43,776,054 | |
step:901/20000 train_time:1118150ms step_avg:1254.94ms | |
step:1000/20000 val_loss:2.4882 train_time:1242239ms step_avg:1254.79ms perplexity:12.0396 param_count:43,776,054 | |
step:1001/20000 train_time:1243470ms step_avg:1254.76ms | |
step:1100/20000 val_loss:2.4830 train_time:1367778ms step_avg:1254.84ms perplexity:11.9772 param_count:43,776,054 | |
step:1101/20000 train_time:1369036ms step_avg:1254.84ms | |
step:1200/20000 val_loss:2.4750 train_time:1493260ms step_avg:1254.84ms perplexity:11.8811 param_count:43,776,054 | |
step:1201/20000 train_time:1494511ms step_avg:1254.84ms | |
step:1300/20000 val_loss:2.4633 train_time:1618605ms step_avg:1254.73ms perplexity:11.7438 param_count:43,776,054 | |
step:1301/20000 train_time:1619852ms step_avg:1254.73ms | |
step:1400/20000 val_loss:2.4593 train_time:1744014ms step_avg:1254.69ms perplexity:11.6963 param_count:43,776,054 | |
step:1401/20000 train_time:1745257ms step_avg:1254.68ms | |
step:1500/20000 val_loss:2.4555 train_time:1869543ms step_avg:1254.73ms perplexity:11.6517 param_count:43,776,054 | |
step:1501/20000 train_time:1870799ms step_avg:1254.73ms | |
step:1600/20000 val_loss:2.4459 train_time:1995035ms step_avg:1254.74ms perplexity:11.5404 param_count:43,776,054 | |
step:1601/20000 train_time:1996290ms step_avg:1254.74ms | |
step:1700/20000 val_loss:2.4421 train_time:2120464ms step_avg:1254.71ms perplexity:11.4971 param_count:43,776,054 | |
step:1701/20000 train_time:2121711ms step_avg:1254.71ms | |
step:1800/20000 val_loss:2.4366 train_time:2246044ms step_avg:1254.77ms perplexity:11.4339 param_count:43,776,054 | |
step:1801/20000 train_time:2247296ms step_avg:1254.77ms | |
step:1900/20000 val_loss:2.4320 train_time:2371625ms step_avg:1254.83ms perplexity:11.3814 param_count:43,776,054 | |
step:1901/20000 train_time:2372861ms step_avg:1254.82ms | |
step:2000/20000 val_loss:2.4286 train_time:2497120ms step_avg:1254.83ms perplexity:11.3425 param_count:43,776,054 | |
step:2001/20000 train_time:2498345ms step_avg:1254.82ms | |
step:2100/20000 val_loss:2.4216 train_time:2622502ms step_avg:1254.79ms perplexity:11.2643 param_count:43,776,054 | |
step:2101/20000 train_time:2623751ms step_avg:1254.78ms | |
step:2200/20000 val_loss:2.4191 train_time:2747866ms step_avg:1254.73ms perplexity:11.2354 param_count:43,776,054 | |
step:2201/20000 train_time:2749125ms step_avg:1254.74ms | |
step:2300/20000 val_loss:2.4125 train_time:2873381ms step_avg:1254.75ms perplexity:11.1623 param_count:43,776,054 | |
step:2301/20000 train_time:2874635ms step_avg:1254.75ms | |
step:2400/20000 val_loss:2.4094 train_time:2998857ms step_avg:1254.75ms perplexity:11.1270 param_count:43,776,054 | |
step:2401/20000 train_time:3000102ms step_avg:1254.75ms | |
step:2500/20000 val_loss:2.4045 train_time:3124269ms step_avg:1254.73ms perplexity:11.0728 param_count:43,776,054 | |
step:2501/20000 train_time:3125527ms step_avg:1254.73ms | |
step:2600/20000 val_loss:2.4022 train_time:3249942ms step_avg:1254.80ms perplexity:11.0469 param_count:43,776,054 | |
step:2601/20000 train_time:3251232ms step_avg:1254.82ms | |
step:2700/20000 val_loss:2.4009 train_time:3375544ms step_avg:1254.85ms perplexity:11.0331 param_count:43,776,054 | |
step:2701/20000 train_time:3376789ms step_avg:1254.85ms | |
step:2800/20000 val_loss:2.3951 train_time:3501249ms step_avg:1254.93ms perplexity:10.9688 param_count:43,776,054 | |
step:2801/20000 train_time:3502494ms step_avg:1254.92ms | |
step:2900/20000 val_loss:2.3942 train_time:3626947ms step_avg:1255.00ms perplexity:10.9592 param_count:43,776,054 | |
step:2901/20000 train_time:3628193ms step_avg:1255.00ms | |
step:3000/20000 val_loss:2.3905 train_time:3752644ms step_avg:1255.06ms perplexity:10.9194 param_count:43,776,054 | |
step:3001/20000 train_time:3753878ms step_avg:1255.06ms | |
step:3100/20000 val_loss:2.3909 train_time:3878438ms step_avg:1255.16ms perplexity:10.9233 param_count:43,776,054 | |
step:3101/20000 train_time:3879683ms step_avg:1255.15ms | |
step:3200/20000 val_loss:2.3841 train_time:4004160ms step_avg:1255.22ms perplexity:10.8492 param_count:43,776,054 | |
step:3201/20000 train_time:4005425ms step_avg:1255.23ms | |
step:3300/20000 val_loss:2.3822 train_time:4130050ms step_avg:1255.33ms perplexity:10.8286 param_count:43,776,054 | |
step:3301/20000 train_time:4131299ms step_avg:1255.33ms | |
step:3400/20000 val_loss:2.3801 train_time:4255956ms step_avg:1255.44ms perplexity:10.8056 param_count:43,776,054 | |
step:3401/20000 train_time:4257212ms step_avg:1255.44ms | |
step:3500/20000 val_loss:2.3773 train_time:4381807ms step_avg:1255.53ms perplexity:10.7758 param_count:43,776,054 | |
step:3501/20000 train_time:4383066ms step_avg:1255.53ms | |
step:3600/20000 val_loss:2.3731 train_time:4507559ms step_avg:1255.59ms perplexity:10.7303 param_count:43,776,054 | |
step:3601/20000 train_time:4508813ms step_avg:1255.59ms | |
step:3700/20000 val_loss:2.3717 train_time:4633642ms step_avg:1255.73ms perplexity:10.7154 param_count:43,776,054 | |
step:3701/20000 train_time:4634895ms step_avg:1255.73ms | |
step:3800/20000 val_loss:2.3681 train_time:4759427ms step_avg:1255.79ms perplexity:10.6771 param_count:43,776,054 | |
step:3801/20000 train_time:4760679ms step_avg:1255.78ms | |
step:3900/20000 val_loss:2.3657 train_time:4885386ms step_avg:1255.88ms perplexity:10.6511 param_count:43,776,054 | |
step:3901/20000 train_time:4886660ms step_avg:1255.89ms | |
step:4000/20000 val_loss:2.3663 train_time:5011049ms step_avg:1255.90ms perplexity:10.6579 param_count:43,776,054 | |
step:4001/20000 train_time:5012294ms step_avg:1255.90ms | |
step:4100/20000 val_loss:2.3610 train_time:5136585ms step_avg:1255.89ms perplexity:10.6020 param_count:43,776,054 | |
step:4101/20000 train_time:5137845ms step_avg:1255.89ms | |
step:4200/20000 val_loss:2.3606 train_time:5262181ms step_avg:1255.89ms perplexity:10.5969 param_count:43,776,054 | |
step:4201/20000 train_time:5263432ms step_avg:1255.89ms | |
step:4300/20000 val_loss:2.3555 train_time:5387708ms step_avg:1255.88ms perplexity:10.5436 param_count:43,776,054 | |
step:4301/20000 train_time:5388969ms step_avg:1255.88ms | |
step:4400/20000 val_loss:2.3566 train_time:5513415ms step_avg:1255.90ms perplexity:10.5545 param_count:43,776,054 | |
step:4401/20000 train_time:5514669ms step_avg:1255.90ms | |
step:4500/20000 val_loss:2.3564 train_time:5639260ms step_avg:1255.96ms perplexity:10.5525 param_count:43,776,054 | |
step:4501/20000 train_time:5640506ms step_avg:1255.96ms | |
step:4600/20000 val_loss:2.3517 train_time:5765288ms step_avg:1256.05ms perplexity:10.5034 param_count:43,776,054 | |
step:4601/20000 train_time:5766534ms step_avg:1256.05ms | |
step:4700/20000 val_loss:2.3511 train_time:5891197ms step_avg:1256.12ms perplexity:10.4976 param_count:43,776,054 | |
step:4701/20000 train_time:5892453ms step_avg:1256.12ms | |
step:4800/20000 val_loss:2.3451 train_time:6016871ms step_avg:1256.13ms perplexity:10.4347 param_count:43,776,054 | |
step:4801/20000 train_time:6018130ms step_avg:1256.13ms | |
step:4900/20000 val_loss:2.3495 train_time:6142542ms step_avg:1256.14ms perplexity:10.4805 param_count:43,776,054 | |
step:4901/20000 train_time:6143817ms step_avg:1256.15ms | |
step:5000/20000 val_loss:2.3458 train_time:6268648ms step_avg:1256.24ms perplexity:10.4413 param_count:43,776,054 | |
step:5001/20000 train_time:6269894ms step_avg:1256.24ms | |
step:5100/20000 val_loss:2.3433 train_time:6394713ms step_avg:1256.33ms perplexity:10.4152 param_count:43,776,054 | |
step:5101/20000 train_time:6395961ms step_avg:1256.33ms | |
step:5200/20000 val_loss:2.3411 train_time:6520734ms step_avg:1256.40ms perplexity:10.3929 param_count:43,776,054 | |
step:5201/20000 train_time:6521989ms step_avg:1256.40ms | |
step:5300/20000 val_loss:2.3375 train_time:6646886ms step_avg:1256.50ms perplexity:10.3557 param_count:43,776,054 | |
step:5301/20000 train_time:6648141ms step_avg:1256.50ms | |
step:5400/20000 val_loss:2.3391 train_time:6773174ms step_avg:1256.62ms perplexity:10.3716 param_count:43,776,054 | |
step:5401/20000 train_time:6774422ms step_avg:1256.62ms | |
step:5500/20000 val_loss:2.3375 train_time:6899227ms step_avg:1256.69ms perplexity:10.3552 param_count:43,776,054 | |
step:5501/20000 train_time:6900480ms step_avg:1256.69ms | |
step:5600/20000 val_loss:2.3325 train_time:7025417ms step_avg:1256.78ms perplexity:10.3036 param_count:43,776,054 | |
step:5601/20000 train_time:7026668ms step_avg:1256.78ms | |
step:5700/20000 val_loss:2.3334 train_time:7151468ms step_avg:1256.85ms perplexity:10.3127 param_count:43,776,054 | |
step:5701/20000 train_time:7152747ms step_avg:1256.85ms | |
step:5800/20000 val_loss:2.3297 train_time:7277658ms step_avg:1256.94ms perplexity:10.2751 param_count:43,776,054 | |
step:5801/20000 train_time:7278902ms step_avg:1256.93ms | |
step:5900/20000 val_loss:2.3316 train_time:7403729ms step_avg:1257.00ms perplexity:10.2942 param_count:43,776,054 | |
step:5901/20000 train_time:7404988ms step_avg:1257.00ms | |
step:6000/20000 val_loss:2.3333 train_time:7530014ms step_avg:1257.10ms perplexity:10.3115 param_count:43,776,054 | |
step:6001/20000 train_time:7531261ms step_avg:1257.10ms | |
step:6100/20000 val_loss:2.3265 train_time:7656363ms step_avg:1257.20ms perplexity:10.2417 param_count:43,776,054 | |
step:6101/20000 train_time:7657608ms step_avg:1257.20ms | |
step:6200/20000 val_loss:2.3261 train_time:7782467ms step_avg:1257.26ms perplexity:10.2379 param_count:43,776,054 | |
step:6201/20000 train_time:7783725ms step_avg:1257.26ms | |
step:6300/20000 val_loss:2.3233 train_time:7908658ms step_avg:1257.34ms perplexity:10.2096 param_count:43,776,054 | |
step:6301/20000 train_time:7909907ms step_avg:1257.34ms | |
step:6400/20000 val_loss:2.3238 train_time:8034913ms step_avg:1257.42ms perplexity:10.2144 param_count:43,776,054 | |
step:6401/20000 train_time:8036175ms step_avg:1257.42ms | |
step:6500/20000 val_loss:2.3218 train_time:8161083ms step_avg:1257.49ms perplexity:10.1939 param_count:43,776,054 | |
step:6501/20000 train_time:8162350ms step_avg:1257.49ms | |
step:6600/20000 val_loss:2.3223 train_time:8287001ms step_avg:1257.51ms perplexity:10.1989 param_count:43,776,054 | |
step:6601/20000 train_time:8288272ms step_avg:1257.51ms | |
step:6700/20000 val_loss:2.3201 train_time:8413003ms step_avg:1257.55ms perplexity:10.1764 param_count:43,776,054 | |
step:6701/20000 train_time:8414245ms step_avg:1257.55ms | |
step:6800/20000 val_loss:2.3169 train_time:8539013ms step_avg:1257.59ms perplexity:10.1445 param_count:43,776,054 | |
step:6801/20000 train_time:8540260ms step_avg:1257.59ms | |
step:6900/20000 val_loss:2.3179 train_time:8665013ms step_avg:1257.62ms perplexity:10.1539 param_count:43,776,054 | |
step:6901/20000 train_time:8666257ms step_avg:1257.62ms | |
step:7000/20000 val_loss:2.3155 train_time:8790877ms step_avg:1257.64ms perplexity:10.1302 param_count:43,776,054 | |
step:7001/20000 train_time:8792106ms step_avg:1257.63ms | |
step:7100/20000 val_loss:2.3154 train_time:8916794ms step_avg:1257.66ms perplexity:10.1288 param_count:43,776,054 | |
step:7101/20000 train_time:8918052ms step_avg:1257.66ms | |
step:7200/20000 val_loss:2.3130 train_time:9042938ms step_avg:1257.71ms perplexity:10.1045 param_count:43,776,054 | |
step:7201/20000 train_time:9044183ms step_avg:1257.71ms | |
step:7300/20000 val_loss:2.3113 train_time:9169033ms step_avg:1257.75ms perplexity:10.0873 param_count:43,776,054 | |
step:7301/20000 train_time:9170284ms step_avg:1257.75ms | |
step:7400/20000 val_loss:2.3114 train_time:9295264ms step_avg:1257.82ms perplexity:10.0881 param_count:43,776,054 | |
step:7401/20000 train_time:9296517ms step_avg:1257.82ms | |
step:7500/20000 val_loss:2.3101 train_time:9421561ms step_avg:1257.89ms perplexity:10.0752 param_count:43,776,054 | |
step:7501/20000 train_time:9422811ms step_avg:1257.88ms | |
step:7600/20000 val_loss:2.3102 train_time:9547782ms step_avg:1257.94ms perplexity:10.0760 param_count:43,776,054 | |
step:7601/20000 train_time:9549038ms step_avg:1257.94ms | |
step:7700/20000 val_loss:2.3106 train_time:9674042ms step_avg:1258.00ms perplexity:10.0808 param_count:43,776,054 | |
step:7701/20000 train_time:9675311ms step_avg:1258.00ms | |
step:7800/20000 val_loss:2.3050 train_time:9800418ms step_avg:1258.08ms perplexity:10.0239 param_count:43,776,054 | |
step:7801/20000 train_time:9801673ms step_avg:1258.08ms | |
step:7900/20000 val_loss:2.3049 train_time:9926654ms step_avg:1258.13ms perplexity:10.0229 param_count:43,776,054 | |
step:7901/20000 train_time:9927915ms step_avg:1258.13ms | |
step:8000/20000 val_loss:2.3069 train_time:10052961ms step_avg:1258.19ms perplexity:10.0435 param_count:43,776,054 | |
step:8001/20000 train_time:10054198ms step_avg:1258.19ms | |
step:8100/20000 val_loss:2.3043 train_time:10179305ms step_avg:1258.26ms perplexity:10.0167 param_count:43,776,054 | |
step:8101/20000 train_time:10180543ms step_avg:1258.26ms | |
step:8200/20000 val_loss:2.3023 train_time:10305549ms step_avg:1258.31ms perplexity:9.9967 param_count:43,776,054 | |
step:8201/20000 train_time:10306800ms step_avg:1258.31ms | |
step:8300/20000 val_loss:2.3009 train_time:10431694ms step_avg:1258.35ms perplexity:9.9829 param_count:43,776,054 | |
step:8301/20000 train_time:10432978ms step_avg:1258.35ms | |
step:8400/20000 val_loss:2.3031 train_time:10558009ms step_avg:1258.40ms perplexity:10.0054 param_count:43,776,054 | |
step:8401/20000 train_time:10559291ms step_avg:1258.41ms | |
step:8500/20000 val_loss:2.3018 train_time:10684324ms step_avg:1258.46ms perplexity:9.9919 param_count:43,776,054 | |
step:8501/20000 train_time:10685569ms step_avg:1258.46ms | |
step:8600/20000 val_loss:2.3002 train_time:10810772ms step_avg:1258.53ms perplexity:9.9761 param_count:43,776,054 | |
step:8601/20000 train_time:10812047ms step_avg:1258.53ms | |
step:8700/20000 val_loss:2.2959 train_time:10937131ms step_avg:1258.59ms perplexity:9.9337 param_count:43,776,054 | |
step:8701/20000 train_time:10938393ms step_avg:1258.59ms | |
step:8800/20000 val_loss:2.2953 train_time:11063473ms step_avg:1258.64ms perplexity:9.9271 param_count:43,776,054 | |
step:8801/20000 train_time:11064781ms step_avg:1258.65ms | |
step:8900/20000 val_loss:2.2978 train_time:11190118ms step_avg:1258.73ms perplexity:9.9521 param_count:43,776,054 | |
step:8901/20000 train_time:11191376ms step_avg:1258.73ms | |
step:9000/20000 val_loss:2.2982 train_time:11316298ms step_avg:1258.77ms perplexity:9.9561 param_count:43,776,054 | |
step:9001/20000 train_time:11317516ms step_avg:1258.76ms | |
step:9100/20000 val_loss:2.2941 train_time:11442686ms step_avg:1258.82ms perplexity:9.9152 param_count:43,776,054 | |
step:9101/20000 train_time:11443935ms step_avg:1258.82ms | |
step:9200/20000 val_loss:2.2929 train_time:11569331ms step_avg:1258.90ms perplexity:9.9035 param_count:43,776,054 | |
step:9201/20000 train_time:11570582ms step_avg:1258.90ms | |
step:9300/20000 val_loss:2.2907 train_time:11695781ms step_avg:1258.96ms perplexity:9.8821 param_count:43,776,054 | |
step:9301/20000 train_time:11697053ms step_avg:1258.97ms | |
step:9400/20000 val_loss:2.2923 train_time:11822040ms step_avg:1259.00ms perplexity:9.8972 param_count:43,776,054 | |
step:9401/20000 train_time:11823295ms step_avg:1259.00ms | |
step:9500/20000 val_loss:2.2905 train_time:11948559ms step_avg:1259.07ms perplexity:9.8800 param_count:43,776,054 | |
step:9501/20000 train_time:11949808ms step_avg:1259.07ms | |
step:9600/20000 val_loss:2.2897 train_time:12075047ms step_avg:1259.13ms perplexity:9.8715 param_count:43,776,054 | |
step:9601/20000 train_time:12076294ms step_avg:1259.13ms | |
step:9700/20000 val_loss:2.2879 train_time:12201137ms step_avg:1259.15ms perplexity:9.8539 param_count:43,776,054 | |
step:9701/20000 train_time:12202398ms step_avg:1259.15ms | |
step:9800/20000 val_loss:2.2817 train_time:12327283ms step_avg:1259.17ms perplexity:9.7929 param_count:43,776,054 | |
step:9801/20000 train_time:12328531ms step_avg:1259.17ms | |
step:9900/20000 val_loss:2.2839 train_time:12453496ms step_avg:1259.20ms perplexity:9.8151 param_count:43,776,054 | |
step:9901/20000 train_time:12454746ms step_avg:1259.20ms | |
step:10000/20000 val_loss:2.2859 train_time:12579800ms step_avg:1259.24ms perplexity:9.8342 param_count:43,776,054 | |
step:10001/20000 train_time:12581039ms step_avg:1259.24ms | |
step:10100/20000 val_loss:2.2849 train_time:12706175ms step_avg:1259.28ms perplexity:9.8248 param_count:43,776,054 | |
step:10101/20000 train_time:12707425ms step_avg:1259.28ms | |
step:10200/20000 val_loss:2.2843 train_time:12832941ms step_avg:1259.37ms perplexity:9.8191 param_count:43,776,054 | |
step:10201/20000 train_time:12834203ms step_avg:1259.37ms | |
step:10300/20000 val_loss:2.2840 train_time:12959577ms step_avg:1259.43ms perplexity:9.8160 param_count:43,776,054 | |
step:10301/20000 train_time:12960827ms step_avg:1259.43ms | |
step:10400/20000 val_loss:2.2816 train_time:13086232ms step_avg:1259.50ms perplexity:9.7921 param_count:43,776,054 | |
step:10401/20000 train_time:13087491ms step_avg:1259.50ms | |
step:10500/20000 val_loss:2.2801 train_time:13212885ms step_avg:1259.57ms perplexity:9.7772 param_count:43,776,054 | |
step:10501/20000 train_time:13214160ms step_avg:1259.57ms | |
step:10600/20000 val_loss:2.2805 train_time:13339526ms step_avg:1259.63ms perplexity:9.7814 param_count:43,776,054 | |
step:10601/20000 train_time:13340838ms step_avg:1259.64ms | |
step:10700/20000 val_loss:2.2822 train_time:13466127ms step_avg:1259.69ms perplexity:9.7978 param_count:43,776,054 | |
step:10701/20000 train_time:13467385ms step_avg:1259.69ms | |
step:10800/20000 val_loss:2.2791 train_time:13592530ms step_avg:1259.73ms perplexity:9.7679 param_count:43,776,054 | |
step:10801/20000 train_time:13593815ms step_avg:1259.74ms | |
step:10900/20000 val_loss:2.2767 train_time:13718860ms step_avg:1259.77ms perplexity:9.7448 param_count:43,776,054 | |
step:10901/20000 train_time:13720113ms step_avg:1259.77ms | |
step:11000/20000 val_loss:2.2798 train_time:13845398ms step_avg:1259.82ms perplexity:9.7749 param_count:43,776,054 | |
step:11001/20000 train_time:13846654ms step_avg:1259.82ms | |
step:11100/20000 val_loss:2.2747 train_time:13971864ms step_avg:1259.86ms perplexity:9.7250 param_count:43,776,054 | |
step:11101/20000 train_time:13973138ms step_avg:1259.86ms | |
step:11200/20000 val_loss:2.2772 train_time:14098323ms step_avg:1259.90ms perplexity:9.7492 param_count:43,776,054 | |
step:11201/20000 train_time:14099569ms step_avg:1259.90ms | |
step:11300/20000 val_loss:2.2732 train_time:14224578ms step_avg:1259.93ms perplexity:9.7106 param_count:43,776,054 | |
step:11301/20000 train_time:14225838ms step_avg:1259.93ms | |
step:11400/20000 val_loss:2.2744 train_time:14350976ms step_avg:1259.96ms perplexity:9.7219 param_count:43,776,054 | |
step:11401/20000 train_time:14352226ms step_avg:1259.96ms | |
step:11500/20000 val_loss:2.2703 train_time:14477368ms step_avg:1260.00ms perplexity:9.6827 param_count:43,776,054 | |
step:11501/20000 train_time:14478622ms step_avg:1260.00ms | |
step:11600/20000 val_loss:2.2720 train_time:14604106ms step_avg:1260.06ms perplexity:9.6989 param_count:43,776,054 | |
step:11601/20000 train_time:14605356ms step_avg:1260.06ms | |
step:11700/20000 val_loss:2.2749 train_time:14730805ms step_avg:1260.12ms perplexity:9.7271 param_count:43,776,054 | |
step:11701/20000 train_time:14732070ms step_avg:1260.12ms | |
step:11800/20000 val_loss:2.2692 train_time:14857551ms step_avg:1260.18ms perplexity:9.6714 param_count:43,776,054 | |
step:11801/20000 train_time:14858804ms step_avg:1260.18ms | |
step:11900/20000 val_loss:2.2709 train_time:14984216ms step_avg:1260.24ms perplexity:9.6878 param_count:43,776,054 | |
step:11901/20000 train_time:14985490ms step_avg:1260.24ms | |
step:12000/20000 val_loss:2.2709 train_time:15110972ms step_avg:1260.30ms perplexity:9.6885 param_count:43,776,054 | |
step:12001/20000 train_time:15112201ms step_avg:1260.30ms | |
step:12100/20000 val_loss:2.2685 train_time:15237413ms step_avg:1260.33ms perplexity:9.6654 param_count:43,776,054 | |
step:12101/20000 train_time:15238669ms step_avg:1260.33ms | |
step:12200/20000 val_loss:2.2652 train_time:15363923ms step_avg:1260.37ms perplexity:9.6326 param_count:43,776,054 | |
step:12201/20000 train_time:15365188ms step_avg:1260.37ms | |
step:12300/20000 val_loss:2.2674 train_time:15490331ms step_avg:1260.40ms perplexity:9.6538 param_count:43,776,054 | |
step:12301/20000 train_time:15491572ms step_avg:1260.40ms | |
step:12400/20000 val_loss:2.2673 train_time:15616520ms step_avg:1260.41ms perplexity:9.6535 param_count:43,776,054 | |
step:12401/20000 train_time:15617762ms step_avg:1260.41ms | |
step:12500/20000 val_loss:2.2636 train_time:15743120ms step_avg:1260.46ms perplexity:9.6180 param_count:43,776,054 | |
step:12501/20000 train_time:15744391ms step_avg:1260.46ms | |
step:12600/20000 val_loss:2.2653 train_time:15869872ms step_avg:1260.51ms perplexity:9.6339 param_count:43,776,054 | |
step:12601/20000 train_time:15871137ms step_avg:1260.51ms | |
step:12700/20000 val_loss:2.2677 train_time:15996710ms step_avg:1260.58ms perplexity:9.6568 param_count:43,776,054 | |
step:12701/20000 train_time:15997957ms step_avg:1260.58ms | |
step:12800/20000 val_loss:2.2639 train_time:16123187ms step_avg:1260.61ms perplexity:9.6204 param_count:43,776,054 | |
step:12801/20000 train_time:16124429ms step_avg:1260.61ms | |
step:12900/20000 val_loss:2.2636 train_time:16249506ms step_avg:1260.63ms perplexity:9.6175 param_count:43,776,054 | |
step:12901/20000 train_time:16250781ms step_avg:1260.63ms | |
step:13000/20000 val_loss:2.2627 train_time:16376328ms step_avg:1260.69ms perplexity:9.6089 param_count:43,776,054 | |
step:13001/20000 train_time:16377606ms step_avg:1260.69ms | |
step:13100/20000 val_loss:2.2603 train_time:16503316ms step_avg:1260.76ms perplexity:9.5863 param_count:43,776,054 | |
step:13101/20000 train_time:16504565ms step_avg:1260.76ms | |
step:13200/20000 val_loss:2.2593 train_time:16630123ms step_avg:1260.81ms perplexity:9.5763 param_count:43,776,054 | |
step:13201/20000 train_time:16631397ms step_avg:1260.81ms | |
step:13300/20000 val_loss:2.2595 train_time:16756924ms step_avg:1260.87ms perplexity:9.5782 param_count:43,776,054 | |
step:13301/20000 train_time:16758176ms step_avg:1260.87ms | |
step:13400/20000 val_loss:2.2592 train_time:16884251ms step_avg:1260.96ms perplexity:9.5755 param_count:43,776,054 | |
step:13401/20000 train_time:16885533ms step_avg:1260.96ms | |
step:13500/20000 val_loss:2.2627 train_time:17011269ms step_avg:1261.03ms perplexity:9.6088 param_count:43,776,054 | |
step:13501/20000 train_time:17012552ms step_avg:1261.03ms | |
step:13600/20000 val_loss:2.2585 train_time:17138141ms step_avg:1261.08ms perplexity:9.5690 param_count:43,776,054 | |
step:13601/20000 train_time:17139379ms step_avg:1261.08ms | |
step:13700/20000 val_loss:2.2541 train_time:17265303ms step_avg:1261.16ms perplexity:9.5266 param_count:43,776,054 | |
step:13701/20000 train_time:17266587ms step_avg:1261.16ms | |
step:13800/20000 val_loss:2.2565 train_time:17392088ms step_avg:1261.21ms perplexity:9.5494 param_count:43,776,054 | |
step:13801/20000 train_time:17393347ms step_avg:1261.21ms | |
step:13900/20000 val_loss:2.2553 train_time:17518638ms step_avg:1261.24ms perplexity:9.5383 param_count:43,776,054 | |
step:13901/20000 train_time:17519912ms step_avg:1261.24ms | |
step:14000/20000 val_loss:2.2577 train_time:17645617ms step_avg:1261.30ms perplexity:9.5610 param_count:43,776,054 | |
step:14001/20000 train_time:17646858ms step_avg:1261.30ms | |
step:14100/20000 val_loss:2.2537 train_time:17772356ms step_avg:1261.35ms perplexity:9.5234 param_count:43,776,054 | |
step:14101/20000 train_time:17773627ms step_avg:1261.35ms | |
step:14200/20000 val_loss:2.2562 train_time:17899109ms step_avg:1261.39ms perplexity:9.5468 param_count:43,776,054 | |
step:14201/20000 train_time:17900371ms step_avg:1261.39ms | |
step:14300/20000 val_loss:2.2538 train_time:18025825ms step_avg:1261.43ms perplexity:9.5243 param_count:43,776,054 | |
step:14301/20000 train_time:18027076ms step_avg:1261.43ms | |
step:14400/20000 val_loss:2.2505 train_time:18152779ms step_avg:1261.49ms perplexity:9.4928 param_count:43,776,054 | |
step:14401/20000 train_time:18154050ms step_avg:1261.49ms | |
step:14500/20000 val_loss:2.2523 train_time:18279430ms step_avg:1261.52ms perplexity:9.5100 param_count:43,776,054 | |
step:14501/20000 train_time:18280709ms step_avg:1261.52ms | |
step:14600/20000 val_loss:2.2511 train_time:18406147ms step_avg:1261.56ms perplexity:9.4986 param_count:43,776,054 | |
step:14601/20000 train_time:18407423ms step_avg:1261.56ms | |
step:14700/20000 val_loss:2.2487 train_time:18533068ms step_avg:1261.61ms perplexity:9.4755 param_count:43,776,054 | |
step:14701/20000 train_time:18534334ms step_avg:1261.61ms | |
step:14800/20000 val_loss:2.2513 train_time:18659830ms step_avg:1261.65ms perplexity:9.4996 param_count:43,776,054 | |
step:14801/20000 train_time:18661091ms step_avg:1261.65ms | |
step:14900/20000 val_loss:2.2471 train_time:18786636ms step_avg:1261.69ms perplexity:9.4606 param_count:43,776,054 | |
step:14901/20000 train_time:18787903ms step_avg:1261.70ms | |
step:15000/20000 val_loss:2.2481 train_time:18913325ms step_avg:1261.73ms perplexity:9.4694 param_count:43,776,054 | |
step:15001/20000 train_time:18914583ms step_avg:1261.73ms | |
step:15100/20000 val_loss:2.2496 train_time:19040035ms step_avg:1261.77ms perplexity:9.4837 param_count:43,776,054 | |
step:15101/20000 train_time:19041286ms step_avg:1261.76ms | |
step:15200/20000 val_loss:2.2480 train_time:19166943ms step_avg:1261.81ms perplexity:9.4683 param_count:43,776,054 | |
step:15201/20000 train_time:19168200ms step_avg:1261.81ms | |
step:15300/20000 val_loss:2.2471 train_time:19293758ms step_avg:1261.85ms perplexity:9.4598 param_count:43,776,054 | |
step:15301/20000 train_time:19295000ms step_avg:1261.85ms | |
step:15400/20000 val_loss:2.2485 train_time:19420818ms step_avg:1261.91ms perplexity:9.4739 param_count:43,776,054 | |
step:15401/20000 train_time:19422067ms step_avg:1261.91ms | |
step:15500/20000 val_loss:2.2434 train_time:19547437ms step_avg:1261.94ms perplexity:9.4251 param_count:43,776,054 | |
step:15501/20000 train_time:19548691ms step_avg:1261.94ms | |
step:15600/20000 val_loss:2.2443 train_time:19674168ms step_avg:1261.97ms perplexity:9.4336 param_count:43,776,054 | |
step:15601/20000 train_time:19675419ms step_avg:1261.97ms | |
step:15700/20000 val_loss:2.2453 train_time:19800840ms step_avg:1262.00ms perplexity:9.4436 param_count:43,776,054 | |
step:15701/20000 train_time:19802093ms step_avg:1262.00ms | |
step:15800/20000 val_loss:2.2449 train_time:19927533ms step_avg:1262.04ms perplexity:9.4397 param_count:43,776,054 | |
step:15801/20000 train_time:19928810ms step_avg:1262.04ms | |
step:15900/20000 val_loss:2.2401 train_time:20054331ms step_avg:1262.07ms perplexity:9.3946 param_count:43,776,054 | |
step:15901/20000 train_time:20055586ms step_avg:1262.07ms | |
step:16000/20000 val_loss:2.2418 train_time:20181126ms step_avg:1262.11ms perplexity:9.4104 param_count:43,776,054 | |
step:16001/20000 train_time:20182364ms step_avg:1262.11ms | |
step:16100/20000 val_loss:2.2434 train_time:20308016ms step_avg:1262.15ms perplexity:9.4251 param_count:43,776,054 | |
step:16101/20000 train_time:20309274ms step_avg:1262.15ms | |
step:16200/20000 val_loss:2.2418 train_time:20435148ms step_avg:1262.21ms perplexity:9.4100 param_count:43,776,054 | |
step:16201/20000 train_time:20436437ms step_avg:1262.21ms | |
step:16300/20000 val_loss:2.2442 train_time:20562061ms step_avg:1262.25ms perplexity:9.4330 param_count:43,776,054 | |
step:16301/20000 train_time:20563318ms step_avg:1262.25ms | |
step:16400/20000 val_loss:2.2395 train_time:20689069ms step_avg:1262.30ms perplexity:9.3890 param_count:43,776,054 | |
step:16401/20000 train_time:20690321ms step_avg:1262.30ms | |
step:16500/20000 val_loss:2.2413 train_time:20816039ms step_avg:1262.34ms perplexity:9.4056 param_count:43,776,054 | |
step:16501/20000 train_time:20817295ms step_avg:1262.34ms | |
step:16600/20000 val_loss:2.2382 train_time:20943020ms step_avg:1262.39ms perplexity:9.3769 param_count:43,776,054 | |
step:16601/20000 train_time:20944273ms step_avg:1262.39ms | |
step:16700/20000 val_loss:2.2404 train_time:21069889ms step_avg:1262.43ms perplexity:9.3972 param_count:43,776,054 | |
step:16701/20000 train_time:21071176ms step_avg:1262.43ms | |
step:16800/20000 val_loss:2.2365 train_time:21196770ms step_avg:1262.46ms perplexity:9.3608 param_count:43,776,054 | |
step:16801/20000 train_time:21198041ms step_avg:1262.46ms | |
step:16900/20000 val_loss:2.2384 train_time:21323685ms step_avg:1262.50ms perplexity:9.3781 param_count:43,776,054 | |
step:16901/20000 train_time:21324945ms step_avg:1262.50ms | |
step:17000/20000 val_loss:2.2355 train_time:21450531ms step_avg:1262.54ms perplexity:9.3510 param_count:43,776,054 | |
step:17001/20000 train_time:21451772ms step_avg:1262.54ms | |
step:17100/20000 val_loss:2.2373 train_time:21577151ms step_avg:1262.56ms perplexity:9.3683 param_count:43,776,054 | |
step:17101/20000 train_time:21578416ms step_avg:1262.56ms | |
step:17200/20000 val_loss:2.2345 train_time:21704577ms step_avg:1262.63ms perplexity:9.3416 param_count:43,776,054 | |
step:17201/20000 train_time:21705843ms step_avg:1262.63ms | |
step:17300/20000 val_loss:2.2355 train_time:21831430ms step_avg:1262.66ms perplexity:9.3514 param_count:43,776,054 | |
step:17301/20000 train_time:21832700ms step_avg:1262.66ms | |
step:17400/20000 val_loss:2.2358 train_time:21958455ms step_avg:1262.71ms perplexity:9.3539 param_count:43,776,054 | |
step:17401/20000 train_time:21959713ms step_avg:1262.71ms | |
step:17500/20000 val_loss:2.2366 train_time:22085603ms step_avg:1262.76ms perplexity:9.3618 param_count:43,776,054 | |
step:17501/20000 train_time:22086848ms step_avg:1262.76ms | |
step:17600/20000 val_loss:2.2311 train_time:22213025ms step_avg:1262.82ms perplexity:9.3100 param_count:43,776,054 | |
step:17601/20000 train_time:22214294ms step_avg:1262.82ms | |
step:17700/20000 val_loss:2.2390 train_time:22340201ms step_avg:1262.87ms perplexity:9.3835 param_count:43,776,054 | |
step:17701/20000 train_time:22341454ms step_avg:1262.87ms | |
step:17800/20000 val_loss:2.2308 train_time:22467273ms step_avg:1262.92ms perplexity:9.3074 param_count:43,776,054 | |
step:17801/20000 train_time:22468527ms step_avg:1262.92ms | |
step:17900/20000 val_loss:2.2288 train_time:22594601ms step_avg:1262.97ms perplexity:9.2888 param_count:43,776,054 | |
step:17901/20000 train_time:22595864ms step_avg:1262.97ms | |
step:18000/20000 val_loss:2.2332 train_time:22721603ms step_avg:1263.01ms perplexity:9.3293 param_count:43,776,054 | |
step:18001/20000 train_time:22722831ms step_avg:1263.01ms | |
step:18100/20000 val_loss:2.2311 train_time:22848475ms step_avg:1263.04ms perplexity:9.3102 param_count:43,776,054 | |
step:18101/20000 train_time:22849741ms step_avg:1263.04ms | |
step:18200/20000 val_loss:2.2263 train_time:22975436ms step_avg:1263.08ms perplexity:9.2651 param_count:43,776,054 | |
step:18201/20000 train_time:22976695ms step_avg:1263.08ms | |
step:18300/20000 val_loss:2.2271 train_time:23102483ms step_avg:1263.12ms perplexity:9.2727 param_count:43,776,054 | |
step:18301/20000 train_time:23103772ms step_avg:1263.12ms | |
step:18400/20000 val_loss:2.2318 train_time:23229600ms step_avg:1263.16ms perplexity:9.3163 param_count:43,776,054 | |
step:18401/20000 train_time:23230849ms step_avg:1263.16ms | |
step:18500/20000 val_loss:2.2259 train_time:23356748ms step_avg:1263.21ms perplexity:9.2622 param_count:43,776,054 | |
step:18501/20000 train_time:23358007ms step_avg:1263.21ms | |
step:18600/20000 val_loss:2.2256 train_time:23483851ms step_avg:1263.25ms perplexity:9.2591 param_count:43,776,054 | |
step:18601/20000 train_time:23485119ms step_avg:1263.25ms | |
step:18700/20000 val_loss:2.2267 train_time:23611037ms step_avg:1263.30ms perplexity:9.2692 param_count:43,776,054 | |
step:18701/20000 train_time:23612301ms step_avg:1263.30ms | |
step:18800/20000 val_loss:2.2266 train_time:23738082ms step_avg:1263.34ms perplexity:9.2681 param_count:43,776,054 | |
step:18801/20000 train_time:23739334ms step_avg:1263.34ms | |
step:18900/20000 val_loss:2.2252 train_time:23865265ms step_avg:1263.38ms perplexity:9.2551 param_count:43,776,054 | |
step:18901/20000 train_time:23866518ms step_avg:1263.38ms | |
step:19000/20000 val_loss:2.2251 train_time:23992385ms step_avg:1263.42ms perplexity:9.2540 param_count:43,776,054 | |
step:19001/20000 train_time:23993603ms step_avg:1263.42ms | |
step:19100/20000 val_loss:2.2259 train_time:24119415ms step_avg:1263.46ms perplexity:9.2620 param_count:43,776,054 | |
step:19101/20000 train_time:24120684ms step_avg:1263.46ms | |
step:19200/20000 val_loss:2.2203 train_time:24246711ms step_avg:1263.51ms perplexity:9.2104 param_count:43,776,054 | |
step:19201/20000 train_time:24247976ms step_avg:1263.51ms | |
step:19300/20000 val_loss:2.2143 train_time:24373791ms step_avg:1263.55ms perplexity:9.1552 param_count:43,776,054 | |
step:19301/20000 train_time:24375039ms step_avg:1263.54ms | |
step:19400/20000 val_loss:2.2145 train_time:24500657ms step_avg:1263.57ms perplexity:9.1571 param_count:43,776,054 | |
step:19401/20000 train_time:24501907ms step_avg:1263.57ms | |
step:19500/20000 val_loss:2.2097 train_time:24627547ms step_avg:1263.60ms perplexity:9.1130 param_count:43,776,054 | |
step:19501/20000 train_time:24628804ms step_avg:1263.60ms | |
step:19600/20000 val_loss:2.2035 train_time:24754633ms step_avg:1263.64ms perplexity:9.0571 param_count:43,776,054 | |
step:19601/20000 train_time:24755898ms step_avg:1263.64ms | |
step:19700/20000 val_loss:2.2032 train_time:24881720ms step_avg:1263.67ms perplexity:9.0539 param_count:43,776,054 | |
step:19701/20000 train_time:24882979ms step_avg:1263.67ms | |
step:19800/20000 val_loss:2.1960 train_time:25008841ms step_avg:1263.71ms perplexity:8.9888 param_count:43,776,054 | |
step:19801/20000 train_time:25010119ms step_avg:1263.71ms | |
step:19900/20000 val_loss:2.1929 train_time:25136142ms step_avg:1263.76ms perplexity:8.9615 param_count:43,776,054 | |
step:19901/20000 train_time:25137434ms step_avg:1263.76ms | |
step:20000/20000 val_loss:2.1906 train_time:25263508ms step_avg:1263.81ms perplexity:8.9404 param_count:43,776,054 | |
peak memory consumption training: 17 GiB |
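# --- A minimal sketch, not part of the training script or of the log above ---
# The log reports both val_loss and perplexity; the perplexity column is just
# exp(val_loss), i.e. the exponential of the natural-log cross-entropy.
# The snippet below parses one line copied from the log and checks that
# relationship. The regex and variable names are illustrative assumptions.
import math
import re

line = "step:20000/20000 val_loss:2.1906 train_time:25263508ms step_avg:1263.81ms perplexity:8.9404 param_count:43,776,054"

match = re.search(r"val_loss:(?P<loss>[\d.]+).*?perplexity:(?P<ppl>[\d.]+)", line)
if match:
    loss = float(match.group("loss"))
    reported_ppl = float(match.group("ppl"))
    # Perplexity is the exponential of the reported validation loss.
    computed_ppl = math.exp(loss)
    assert abs(computed_ppl - reported_ppl) < 1e-2
    print(f"val_loss={loss:.4f} -> perplexity={computed_ppl:.4f} (reported {reported_ppl:.4f})")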