3.17 minutes
import os
import sys
from typing import override
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import contextlib
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
import torch
import torch._inductor.config as config
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
# Use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import BlockMask, flex_attention
from torch.nn.parallel import DistributedDataParallel as DDP
config.coordinate_descent_tuning = True

# -----------------------------------------------------------------------------
# Custom operators

@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(
    x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float
) -> tuple[Tensor, Tensor, Tensor]:
    @torch.compile
    def impl(x: Tensor, w: Tensor):
        assert x.is_contiguous() and w.is_contiguous()
        x_f8 = x.mul(x_s).to(torch.float8_e4m3fn)
        w_f8 = w.mul(w_s).to(torch.float8_e4m3fn)
        out = torch._scaled_mm(
            x_f8,
            w_f8.t(),
            out_dtype=torch.bfloat16,
            scale_a=x.new_tensor(1 / x_s, dtype=torch.float32),
            scale_b=x.new_tensor(1 / w_s, dtype=torch.float32),
            use_fast_accum=True,
        )
        return out, x_f8, w_f8

    return impl(x, w)

@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    assert x.ndim == w.ndim == 2
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    return x @ w.t(), x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)

@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(
    g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float
) -> tuple[Tensor, Tensor]:
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
        x_inv_s = grad.new_tensor(1 / x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(1 / w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(1 / grad_s, dtype=torch.float32)
        grad_f8 = grad.mul(grad_s).to(torch.float8_e5m2)
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.t().contiguous().t(),
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        grad_w = torch._scaled_mm(
            x_f8.t().contiguous(),
            grad_f8.t().contiguous().t(),
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).t()
        return grad_x, grad_w

    return impl(g, x_f8, w_f8)

@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    return x_f8.to(torch.bfloat16), w_f8.to(torch.float32)

def backward(ctx, grad_out: Tensor, *_):
    x_f8, w_f8 = ctx.saved_tensors
    x_s, w_s, grad_s = ctx.scales
    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
        grad_out, x_f8, w_f8, x_s, w_s, grad_s
    )
    return grad_x, grad_w, None, None, None

def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    *_, x_s, w_s, grad_s = inputs
    _, x_f8, w_f8 = output
    ctx.save_for_backward(x_f8, w_f8)
    ctx.scales = x_s, w_s, grad_s
    ctx.set_materialize_grads(False)

mm_op.register_autograd(backward, setup_context=setup_context)

def lm_head(x: Tensor, w: Tensor):
    _x = x.flatten(0, -2)
    out: Tensor = torch.ops.nanogpt.mm(_x, w, x_s=2.0, w_s=32.0, grad_s=2.0**29)[0]
    return out.reshape(*x.shape[:-1], -1)
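
# Note on the scale factors above (an informal sketch of the reasoning, not a
# comment from the original): x_s and w_s stretch activations and weights into
# the representable range of float8_e4m3fn before quantization, and _scaled_mm
# divides the product by the same factors (scale_a=1/x_s, scale_b=1/w_s), so the
# matmul runs in fp8 but the output comes back at the original scale.
# grad_s=2.0**29 plays the same role for incoming gradients, which are otherwise
# small enough to underflow float8_e5m2. E.g. with w_s=32.0, a weight of 0.01 is
# stored as 0.32 in fp8 and rescaled on the way out.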

# -----------------------------------------------------------------------------
# Muon optimizer

@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X
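
# A quick sanity check of the docstring's claim (a hedged sketch; not executed
# as part of the training run):
#   G = torch.randn(768, 768, device="cuda")
#   X = zeropower_via_newtonschulz5(G, steps=5)
#   print(torch.linalg.svdvals(X.float()))  # singular values scatter in ~(0.5, 1.5)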

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
      parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: "list[Tensor]" = [*params]
        assert all(isinstance(p, Tensor) for p in params)
        sizes = {p.numel() for p in params}
        def create_update_buffer(size: int):
            b = torch.empty(world_size, size, dtype=torch.bfloat16, device="cuda")
            return dict(update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
        param_groups = [
            dict(params=[p for p in params if p.numel() == size], **create_update_buffer(size)) for size in sizes]
        super().__init__(param_groups, defaults)

    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            nesterov = group['nesterov']
            ns_steps = group['ns_steps']
            update_buffer = group['update_buffer']
            update_buffer_views: "list[Tensor]" = group['update_buffer_views']
            # generate weight updates in distributed fashion
            params: "list[Tensor]" = group['params']
            handle = None
            params_world = None
            def update_prev():
                if params_world is None:
                    return
                assert handle is not None
                handle.wait()
                for p_world, g_world in zip(params_world, update_buffer_views):
                    p_world.data.add_(
                        g_world.view_as(p_world),
                        alpha=-lr * max(1, p_world.size(0) / p_world.size(1)) ** 0.5,
                    )
            for base_i in range(len(params))[::world_size]:
                if base_i + rank < len(params):
                    p = params[base_i + rank]
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf: Tensor = state['momentum_buffer']
                    buf.lerp_(g, 1 - momentum)
                    g = g.lerp_(buf, momentum) if nesterov else buf
                    g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten()
                else:
                    g = update_buffer_views[rank]
                update_prev() # async all_gather instead of sync all_reduce by @YouJiacheng
                handle = dist.all_gather_into_tensor(update_buffer, g, async_op=True)
                params_world = params[base_i : base_i + world_size]
            update_prev()
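
# How the step above is distributed (an informal summary of the code, not an
# original comment): parameters are walked in strides of world_size, so rank r
# orthogonalizes params[r], params[r + world_size], ... Each rank writes its
# flattened update into one row of update_buffer, launches an async all_gather,
# and update_prev() applies the previous round's gathered updates while the
# current all_gather is still in flight, overlapping communication with compute.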

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

def norm(x: Tensor, size: int | None = None):
    if size is None:
        size = x.size(-1)
    return F.rms_norm(x.unflatten(-1, (-1, size)), (size,)).flatten(-2)

class CastedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)

    @override
    def reset_parameters(self) -> None:
        # leave for future tuning
        # 0.5 is a bit better than the default 1/sqrt(3)
        std = 0.5 * (self.in_features ** -0.5)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x):
        return F.linear(x, self.weight.type_as(x))

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len=65536):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0.0, 1.0, steps=dim // 4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i, j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x: Tensor):
        cos, sin = self.cos[None, :x.size(-3), None, :], self.sin[None, :x.size(-3), None, :]
        x1, x2 = x.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x)
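
# Shape note on the half-truncated RoPE above (informal): only the first dim//4
# frequency channels rotate; the second dim//4 get angular_freq = 0, i.e.
# cos = 1 and sin = 0, so those channels pass through unchanged. With
# head_dim = 128 that is 32 rotating pairs and 32 pass-through pairs.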

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, dim)
        self.c_v = CastedLinear(dim, dim)
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(dim // num_heads) # dim // num_heads = head_dim
        self.c_proj = CastedLinear(dim, dim)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, 'Must use batch size = 1 for FlexAttention'
        q = self.c_q(x).view(B, T, self.num_heads, -1)
        k = self.c_k(x).view(B, T, self.num_heads, -1)
        v = self.c_v(x).view(B, T, self.num_heads, -1)
        if ve is None: # skip mid-layers token value embeddings by @YouJiacheng
            v = self.lambdas[0] * v
        else:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask)
        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = CastedLinear(dim, 4 * dim)
        self.c_proj = CastedLinear(4 * dim, dim)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, model_dim: int, num_heads: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(model_dim, num_heads) if layer_idx != 7 else None
        self.mlp = MLP(model_dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
        self.layer_idx = layer_idx

    def forward(self, x, vi, x0, block_mask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), vi, block_mask)
        x = x + self.mlp(norm(x))
        return x

class ValueEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.embed = nn.ModuleList([nn.Embedding(num_embeddings, embedding_dim) for _ in range(3)])

    def forward(self, inputs) -> "list[Tensor | None]":
        ve = [emb(inputs) for emb in self.embed]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved upon @leloykun's U-net structure
        ve = [
            ve[0], ve[1], ve[2], None, None, None,
            None, None, None, ve[0], ve[1], ve[2],
        ]
        return ve

# -----------------------------------------------------------------------------
# The main GPT-2 model

def next_multiple_of_128(v: float | int):
    n = 128
    return next(x for x in range(int(v) + 1 + n)[::n] if x >= v)
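
# Equivalent closed form (informal note): next_multiple_of_128(v) is the
# smallest multiple of 128 that is >= v, i.e. 128 * ceil(v / 128);
# e.g. next_multiple_of_128(50257) == 50304 and next_multiple_of_128(864) == 896.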

class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
        self.value_embeds = ValueEmbedding(vocab_size, model_dim)
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, layer_idx) for layer_idx in range(num_layers)])
        # U-net design by @brendanh0gan
        self.num_encoder_layers = num_layers // 2 # Half of the layers for encoder
        self.num_decoder_layers = num_layers - self.num_encoder_layers # Remaining for decoder
        # Add learnable skip connection weights for decoder layers
        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
        # there are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_128(vocab_size))
        self.lm_head.weight.data.zero_() # @Grad62304977

    def forward(
        self,
        inputs: Tensor,
        targets: Tensor,
        sliding_window_num_blocks: Tensor,
    ):
        BLOCK_SIZE = 128
        assert inputs.ndim == 1
        assert len(inputs) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(inputs) // BLOCK_SIZE
        docs = (inputs == 50256).cumsum(0)
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_mask: Tensor):
            num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
            indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        def create_doc_swc_block_mask(sliding_window_num_blocks: Tensor):
            kv_idx = block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
            q_idx = block_idx[:, None]
            causal_bm = q_idx >= kv_idx
            causal_full_bm = q_idx > kv_idx
            window_bm = q_idx - kv_idx < sliding_window_num_blocks
            window_full_bm = window_bm # block-wise sliding window by @YouJiacheng
            # document_bm = (docs_low[q_idx] <= docs_high[kv_idx]) & (docs_low[kv_idx] <= docs_high[q_idx])
            document_bm = (docs_low[:, None] <= docs_high) & (docs_low <= docs_high[:, None])
            document_full_bm = (docs_low[:, None] == docs_high) & (docs_low == docs_high[:, None])
            nonzero_bm = causal_bm & window_bm & document_bm
            full_bm = causal_full_bm & window_full_bm & document_full_bm
            kv_num_blocks, kv_indices = dense_to_ordered(nonzero_bm & ~full_bm)
            full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)
            return BlockMask.from_kv_blocks(
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )

        block_mask = create_doc_swc_block_mask(sliding_window_num_blocks)
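
        # What from_kv_blocks receives (an informal note, not an original comment):
        # "full" blocks are 128x128 tiles in which every (q, kv) pair is visible,
        # so FlexAttention can skip evaluating mask_mod on them; the remaining
        # nonzero blocks are partial and still apply document_causal elementwise.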
        # forward the GPT model itself
        x = self.embed(inputs)[None]
        x = norm(x) # @Grad62304977
        x0 = x
        ve = self.value_embeds(inputs)
        ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]

        # Store outputs for U-Net skip connections
        skip_connections = []
        # Encoder pass - process only the first half of the blocks
        for i in range(self.num_encoder_layers):
            x = self.blocks[i](x, ve_enc[i], x0, block_mask)
            skip_connections.append(x)
        # Decoder pass - process the remaining blocks with weighted skip connections
        for i in range(self.num_decoder_layers):
            x = x + self.skip_weights[i] * skip_connections.pop()
            # U-net structure on token value embeddings by @leloykun
            x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask)

        x = norm(x)
        if self.training:
            logits = lm_head(x, self.lm_head.weight)
        else:
            logits = self.lm_head(x)
        logits = 30 * torch.sigmoid(logits.float() / 7.5) # @Grad62304977
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return loss
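
# Note on the logit transform above (informal): 30 * torch.sigmoid(z / 7.5)
# squashes logits smoothly into (0, 30); near z = 0 it behaves like 15 + z
# (slope 1), so it acts as a soft cap on extreme logits rather than a rescaling.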

# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _load_data_shard(file: Path):
    # read the 256-int32 header, then the token data that follows it
    header = torch.from_file(f"{file}", False, 256, dtype=torch.int32)
    assert header[0] == 20240520, 'magic number mismatch in the data .bin file'
    assert header[1] == 1, 'unsupported version'
    num_tokens = int(header[2]) # number of tokens (claimed)
    with file.open('rb', buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
        assert nbytes == 2 * num_tokens, 'number of tokens read does not match header?'
    return tokens
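
# The .bin shard layout implied by the reader above: a 256-int32 header with
# header[0] = 20240520 (magic), header[1] = 1 (version), header[2] = token count,
# followed immediately by the tokens as uint16. A writer sketch under those
# assumptions (hypothetical helper, not part of this script):
#   header = np.zeros(256, dtype=np.int32)
#   header[:3] = 20240520, 1, len(tokens)
#   with open(path, "wb") as f:
#       f.write(header.tobytes())
#       f.write(np.asarray(tokens, dtype=np.uint16).tobytes())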

class DistributedDataLoader:
    def __init__(self, filename_pattern: str, batch_size: int):
        assert batch_size % world_size == 0
        self.files = sorted(Path.cwd().glob(filename_pattern))
        self.batch_size = batch_size
        self.reset()

    def reset(self):
        self.next_shard = 0
        self.advance()

    def advance(self): # advance to next data shard
        self.pos = 0
        self.tokens = _load_data_shard(self.files[self.next_shard])
        self.next_shard = (self.next_shard + 1) % len(self.files)

    def next_batch(self):
        local_batch_size = self.batch_size // world_size
        buf = self.tokens[self.pos + rank * local_batch_size:][:local_batch_size + 1]
        # by @YouJiacheng: host side async is sufficient;
        # no performance improvement was observed when introducing a separate stream.
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # inputs
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # targets
        # advance current position and load next shard if necessary
        self.pos += self.batch_size
        if self.pos + self.batch_size + 1 >= len(self.tokens):
            self.advance()
        return inputs, targets

# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    # data
    train_bin = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
    val_bin = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
    val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    # optimization
    batch_size = 8*64*1024 # batch size in tokens
    num_iterations = 1395 # number of iterations to run
    cooldown_frac = 0.4 # fraction of training spent on the linear cooldown of the trapezoidal schedule
    # evaluation and logging
    val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end
    # implementation
    save_checkpoint = False
args = Hyperparameters()

# set up DDP (distributed data parallel). torchrun sets this env variable
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
assert torch.cuda.is_available()
device = torch.device('cuda', local_rank)
torch.cuda.set_device(device)
dist.init_process_group(backend='nccl', device_id=device)
dist.barrier()
master_process = (rank == 0) # this process will do logging, checkpointing etc.

# begin logging
def print0(s, console=False): ... # no-op stub; overridden on the master process below
if master_process:
    run_id = uuid.uuid4()
    (logs_dir := Path("logs")).mkdir(exist_ok=True)
    logfile = logs_dir / f"{run_id}.txt"
    print(logfile.stem)
    def print0(s, console=False):
        with logfile.open("a") as f:
            # if console:
            #     print(s)
            print(s, file=f)

# begin by printing this file (the Python code)
print0(code)
print0('='*100)
# log information about the hardware/software environment this is running on
print0(f'Running Python {sys.version}')
print0(f'Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}')
def nvidia_smi():
    import subprocess # avoid top level import
    return subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0('='*100)

# load data
train_loader = DistributedDataLoader(args.train_bin, args.batch_size)
val_loader = DistributedDataLoader(args.val_bin, args.batch_size)
print0(f'Training dataloader files: {train_loader.files}')
print0(f'Validation dataloader files: {val_loader.files}')
print0('='*100)
inputs_train, targets_train = train_loader.next_batch()

model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768)
model = model.cuda()
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
model: nn.Module = torch.compile(model)
ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True)
sliding_window_num_blocks = torch.tensor(1, dtype=torch.int32, device="cuda")
sw_num_blocks_prev = 1

# collect the parameters to optimize
hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim == 2]
embed_params = [model.embed.weight, *model.value_embeds.parameters()]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]

# init the optimizer(s)
optimizer1 = torch.optim.Adam([
        dict(params=head_params, lr=0.008),
        dict(params=embed_params, lr=0.6),
        dict(params=scalar_params, lr=0.04),
    ],
    betas=(0.8, 0.95),
    fused=True,
)
optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95)
optimizers = [optimizer1, optimizer2]

# learning rate schedule: stable then decay
def get_lr(it):
    t = 1 - it / args.num_iterations # time remaining in training
    assert 1 >= t >= 0
    w = min(t / args.cooldown_frac, 1.0) # 1 -> 0
    return w * 1.0 + (1 - w) * 0.1
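
# Sample values of the multiplier above (informal check): with num_iterations=1395
# and cooldown_frac=0.4, steps 0..837 sit at the full multiplier 1.0; it then
# decays linearly, e.g. at it=1116 (t=0.2, w=0.5) the multiplier is 0.55, and it
# reaches 0.1 at the final step.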
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

# Start training loop
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
train_steps = args.num_iterations
for step in range(train_steps + 1):
    last_step = (step == train_steps)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.perf_counter()
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
    # Linearly increase the sliding window size over training in chunks of 128 from 128 -> 1792. By @fernbear.bsky.social
    frac_done = step / train_steps # training progress
    sw_num_blocks = next_multiple_of_128(max(1, 1728 * frac_done)) // 128
    if sw_num_blocks != sw_num_blocks_prev:
        sliding_window_num_blocks.fill_(sw_num_blocks)
        sw_num_blocks_prev = sw_num_blocks
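    # Sample values of the schedule above (informal): at frac_done=0 the window
    # is 1 block (128 tokens); at frac_done=0.5, 1728*0.5=864 rounds up to 896
    # tokens, i.e. 7 blocks; by the end of training it reaches 1792 tokens (14 blocks).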
    # --------------- VALIDATION SECTION -----------------
    if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        # run validation batches
        ddp_model.eval()
        val_loader.reset()
        val_loss = 0.0
        # calculate the number of steps to take in the val loop.
        assert args.val_tokens % args.batch_size == 0
        val_steps = args.val_tokens // args.batch_size
        for _ in range(val_steps):
            with torch.no_grad():
                inputs_val, targets_val = val_loader.next_batch()
                val_loss += ddp_model(inputs_val, targets_val, sliding_window_num_blocks)
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # logging
        print0(f'step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms', console=True)
        ddp_model.train()
        def abs_cdf_diff(t: Tensor, thresholds: list[float]):
            t = t.abs()
            level = torch.bucketize(t, t.new_tensor(thresholds), out_int32=True) # sum(x > v for v in thresholds)
            cdf_diff = level.flatten().bincount(minlength=len(thresholds) + 1) / t.numel()
            return cdf_diff.tolist()
        print0(f"{abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=}")
        print0(f"{model.lm_head.weight.data.max().item()=}")
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    if last_step:
        if master_process and args.save_checkpoint:
            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f'logs/{run_id}', exist_ok=True)
            torch.save(log, f'logs/{run_id}/state_step{step:06d}.pt')
        # the last step only has the validation loop, so break to avoid training
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    with contextlib.ExitStack() as stack:
        if step >= 5:
            stack.enter_context(torch.compiler.set_stance(skip_guard_eval_unsafe=True))
        ddp_model(inputs_train, targets_train, sliding_window_num_blocks).backward()
        inputs_train, targets_train = train_loader.next_batch()
    # momentum warmup for Muon
    frac = min(step / 300, 1)
    for group in optimizer2.param_groups:
        group['momentum'] = (1 - frac) * 0.85 + frac * 0.95
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    ddp_model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------

    # everything that follows now is just diagnostics, prints, logging, etc.
    approx_time = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f'step:{step+1}/{train_steps} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms', console=True)

print0(
    f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
    f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
)
dist.destroy_process_group()

====================================================================================================
Running Python 3.12.8 (main, Dec 19 2024, 14:33:20) [Clang 18.1.8 ]
Running PyTorch 2.7.0.dev20250110+cu126 compiled for CUDA 12.6
Mon Jan 13 13:02:34 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.6 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:65:02.0 Off | 0 |
| N/A 37C P0 119W / 700W | 7092MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:67:02.0 Off | 0 |
| N/A 44C P0 128W / 700W | 3459MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:69:02.0 Off | 0 |
| N/A 44C P0 122W / 700W | 3459MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:6B:02.0 Off | 0 |
| N/A 38C P0 118W / 700W | 3459MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:6F:02.0 Off | 0 |
| N/A 38C P0 117W / 700W | 3459MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:71:02.0 Off | 0 |
| N/A 44C P0 121W / 700W | 3459MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:73:02.0 Off | 0 |
| N/A 45C P0 126W / 700W | 3459MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:75:02.0 Off | 0 |
| N/A 38C P0 123W / 700W | 3219MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
====================================================================================================
Training dataloader files: [PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000001.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000002.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000003.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000004.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000005.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000006.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000007.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000008.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000009.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000010.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000011.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000012.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000013.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000014.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000015.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000016.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000017.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000018.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000019.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000020.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000021.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000022.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000023.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000024.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000025.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000026.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000027.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000028.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000029.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000030.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000031.bin'), PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_train_000032.bin')]
Validation dataloader files: [PosixPath('/root/modded-nanogpt/data/fineweb10B/fineweb_val_000000.bin')]
====================================================================================================
step:0/1395 val_loss:10.8258 train_time:0ms step_avg:nanms
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
model.lm_head.weight.data.max().item()=0.0
step:1/1395 train_time:23694ms step_avg:nanms
step:2/1395 train_time:23756ms step_avg:nanms
step:3/1395 train_time:23923ms step_avg:nanms
step:4/1395 train_time:24047ms step_avg:nanms
step:5/1395 train_time:24170ms step_avg:nanms
step:6/1395 train_time:24295ms step_avg:nanms
step:7/1395 train_time:24418ms step_avg:nanms
step:8/1395 train_time:24542ms step_avg:nanms
step:9/1395 train_time:24666ms step_avg:nanms
step:10/1395 train_time:24796ms step_avg:nanms
step:11/1395 train_time:130ms step_avg:nanms
step:12/1395 train_time:257ms step_avg:nanms
step:13/1395 train_time:381ms step_avg:127.15ms
step:14/1395 train_time:507ms step_avg:126.67ms
step:15/1395 train_time:630ms step_avg:125.91ms
step:16/1395 train_time:755ms step_avg:125.75ms
step:17/1395 train_time:879ms step_avg:125.63ms
step:18/1395 train_time:1007ms step_avg:125.92ms
step:19/1395 train_time:1135ms step_avg:126.15ms
step:20/1395 train_time:1261ms step_avg:126.14ms
step:21/1395 train_time:1387ms step_avg:126.13ms
step:22/1395 train_time:1512ms step_avg:126.01ms
step:23/1395 train_time:1636ms step_avg:125.85ms
step:24/1395 train_time:1762ms step_avg:125.86ms
step:25/1395 train_time:1887ms step_avg:125.83ms
step:26/1395 train_time:2012ms step_avg:125.78ms
step:27/1395 train_time:2141ms step_avg:125.95ms
step:28/1395 train_time:2267ms step_avg:125.97ms
step:29/1395 train_time:2394ms step_avg:126.02ms
step:30/1395 train_time:2521ms step_avg:126.05ms
step:31/1395 train_time:2647ms step_avg:126.06ms
step:32/1395 train_time:2772ms step_avg:125.99ms
step:33/1395 train_time:2898ms step_avg:126.02ms
step:34/1395 train_time:3025ms step_avg:126.02ms
step:35/1395 train_time:3151ms step_avg:126.05ms
step:36/1395 train_time:3277ms step_avg:126.03ms
step:37/1395 train_time:3403ms step_avg:126.03ms
step:38/1395 train_time:3528ms step_avg:126.01ms
step:39/1395 train_time:3653ms step_avg:125.98ms
step:40/1395 train_time:3778ms step_avg:125.94ms
step:41/1395 train_time:3905ms step_avg:125.97ms
step:42/1395 train_time:4030ms step_avg:125.95ms
step:43/1395 train_time:4156ms step_avg:125.95ms
step:44/1395 train_time:4282ms step_avg:125.93ms
step:45/1395 train_time:4409ms step_avg:125.96ms
step:46/1395 train_time:4535ms step_avg:125.97ms
step:47/1395 train_time:4661ms step_avg:125.97ms
step:48/1395 train_time:4787ms step_avg:125.97ms
step:49/1395 train_time:4912ms step_avg:125.95ms
step:50/1395 train_time:5039ms step_avg:125.97ms
step:51/1395 train_time:5166ms step_avg:126.00ms
step:52/1395 train_time:5290ms step_avg:125.95ms
step:53/1395 train_time:5417ms step_avg:125.99ms
step:54/1395 train_time:5544ms step_avg:126.00ms
step:55/1395 train_time:5669ms step_avg:125.98ms
step:56/1395 train_time:5794ms step_avg:125.95ms
step:57/1395 train_time:5921ms step_avg:125.97ms
step:58/1395 train_time:6047ms step_avg:125.99ms
step:59/1395 train_time:6172ms step_avg:125.97ms
step:60/1395 train_time:6299ms step_avg:125.98ms
step:61/1395 train_time:6426ms step_avg:125.99ms
step:62/1395 train_time:6550ms step_avg:125.97ms
step:63/1395 train_time:6675ms step_avg:125.95ms
step:64/1395 train_time:6801ms step_avg:125.95ms
step:65/1395 train_time:6927ms step_avg:125.95ms
step:66/1395 train_time:7053ms step_avg:125.94ms
step:67/1395 train_time:7178ms step_avg:125.93ms
step:68/1395 train_time:7304ms step_avg:125.93ms
step:69/1395 train_time:7429ms step_avg:125.92ms
step:70/1395 train_time:7556ms step_avg:125.93ms
step:71/1395 train_time:7681ms step_avg:125.93ms
step:72/1395 train_time:7806ms step_avg:125.91ms
step:73/1395 train_time:7932ms step_avg:125.90ms
step:74/1395 train_time:8058ms step_avg:125.90ms
step:75/1395 train_time:8183ms step_avg:125.89ms
step:76/1395 train_time:8309ms step_avg:125.90ms
step:77/1395 train_time:8435ms step_avg:125.89ms
step:78/1395 train_time:8560ms step_avg:125.89ms
step:79/1395 train_time:8687ms step_avg:125.90ms
step:80/1395 train_time:8814ms step_avg:125.91ms
step:81/1395 train_time:8939ms step_avg:125.90ms
step:82/1395 train_time:9066ms step_avg:125.91ms
step:83/1395 train_time:9191ms step_avg:125.90ms
step:84/1395 train_time:9318ms step_avg:125.92ms
step:85/1395 train_time:9443ms step_avg:125.91ms
step:86/1395 train_time:9569ms step_avg:125.90ms
step:87/1395 train_time:9694ms step_avg:125.90ms
step:88/1395 train_time:9821ms step_avg:125.91ms
step:89/1395 train_time:9947ms step_avg:125.91ms
step:90/1395 train_time:10072ms step_avg:125.90ms
step:91/1395 train_time:10197ms step_avg:125.89ms
step:92/1395 train_time:10324ms step_avg:125.90ms
step:93/1395 train_time:10450ms step_avg:125.90ms
step:94/1395 train_time:10574ms step_avg:125.88ms
step:95/1395 train_time:10700ms step_avg:125.88ms
step:96/1395 train_time:10826ms step_avg:125.88ms
step:97/1395 train_time:10951ms step_avg:125.87ms
step:98/1395 train_time:11076ms step_avg:125.86ms
step:99/1395 train_time:11202ms step_avg:125.86ms
step:100/1395 train_time:11327ms step_avg:125.86ms
step:101/1395 train_time:11452ms step_avg:125.85ms
step:102/1395 train_time:11578ms step_avg:125.85ms
step:103/1395 train_time:11705ms step_avg:125.86ms
step:104/1395 train_time:11830ms step_avg:125.85ms
step:105/1395 train_time:11957ms step_avg:125.86ms
step:106/1395 train_time:12085ms step_avg:125.89ms
step:107/1395 train_time:12213ms step_avg:125.91ms
step:108/1395 train_time:12343ms step_avg:125.95ms
step:109/1395 train_time:12470ms step_avg:125.96ms
step:110/1395 train_time:12599ms step_avg:125.99ms
step:111/1395 train_time:12727ms step_avg:126.01ms
step:112/1395 train_time:12855ms step_avg:126.03ms
step:113/1395 train_time:12984ms step_avg:126.06ms
step:114/1395 train_time:13113ms step_avg:126.08ms
step:115/1395 train_time:13242ms step_avg:126.12ms
step:116/1395 train_time:13371ms step_avg:126.14ms
step:117/1395 train_time:13500ms step_avg:126.17ms
step:118/1395 train_time:13627ms step_avg:126.18ms
step:119/1395 train_time:13755ms step_avg:126.20ms
step:120/1395 train_time:13884ms step_avg:126.22ms
step:121/1395 train_time:14012ms step_avg:126.23ms
step:122/1395 train_time:14142ms step_avg:126.27ms
step:123/1395 train_time:14271ms step_avg:126.29ms
step:124/1395 train_time:14399ms step_avg:126.30ms
step:125/1395 train_time:14526ms step_avg:126.32ms
step:125/1395 val_loss:4.3757 train_time:14591ms step_avg:126.88ms
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.012665830552577972, 0.015137831680476665, 0.03153892233967781, 0.06638150662183762, 0.06183234229683876, 0.08457919955253601, 0.12596195936203003, 0.14513088762760162, 0.1760280877351761, 0.1661289483308792, 0.09567415714263916, 0.018227510154247284, 0.0007119215442799032, 9.059501735464437e-07, 0.0, 0.0, 0.0]
model.lm_head.weight.data.max().item()=0.37232813239097595
step:126/1395 train_time:14657ms step_avg:126.36ms
step:127/1395 train_time:14791ms step_avg:126.42ms
step:128/1395 train_time:14919ms step_avg:126.43ms
step:129/1395 train_time:15046ms step_avg:126.44ms
step:130/1395 train_time:15173ms step_avg:126.44ms
step:131/1395 train_time:15300ms step_avg:126.44ms
step:132/1395 train_time:15425ms step_avg:126.44ms
step:133/1395 train_time:15554ms step_avg:126.46ms
step:134/1395 train_time:15688ms step_avg:126.51ms
step:135/1395 train_time:15819ms step_avg:126.55ms
step:136/1395 train_time:15947ms step_avg:126.57ms
step:137/1395 train_time:16075ms step_avg:126.58ms
step:138/1395 train_time:16203ms step_avg:126.58ms
step:139/1395 train_time:16329ms step_avg:126.58ms
step:140/1395 train_time:16457ms step_avg:126.59ms
step:141/1395 train_time:16586ms step_avg:126.61ms
step:142/1395 train_time:16715ms step_avg:126.63ms
step:143/1395 train_time:16844ms step_avg:126.65ms
step:144/1395 train_time:16973ms step_avg:126.67ms
step:145/1395 train_time:17101ms step_avg:126.68ms
step:146/1395 train_time:17228ms step_avg:126.68ms
step:147/1395 train_time:17355ms step_avg:126.68ms
step:148/1395 train_time:17483ms step_avg:126.69ms
step:149/1395 train_time:17611ms step_avg:126.70ms
step:150/1395 train_time:17741ms step_avg:126.72ms
step:151/1395 train_time:17868ms step_avg:126.73ms
step:152/1395 train_time:17999ms step_avg:126.75ms
step:153/1395 train_time:18126ms step_avg:126.76ms
step:154/1395 train_time:18254ms step_avg:126.76ms
step:155/1395 train_time:18381ms step_avg:126.77ms
step:156/1395 train_time:18509ms step_avg:126.77ms
step:157/1395 train_time:18638ms step_avg:126.79ms
step:158/1395 train_time:18768ms step_avg:126.81ms
step:159/1395 train_time:18896ms step_avg:126.82ms
step:160/1395 train_time:19024ms step_avg:126.83ms
step:161/1395 train_time:19152ms step_avg:126.83ms
step:162/1395 train_time:19279ms step_avg:126.84ms
step:163/1395 train_time:19407ms step_avg:126.85ms
step:164/1395 train_time:19536ms step_avg:126.85ms
step:165/1395 train_time:19664ms step_avg:126.87ms
step:166/1395 train_time:19794ms step_avg:126.88ms
step:167/1395 train_time:19922ms step_avg:126.89ms
step:168/1395 train_time:20051ms step_avg:126.91ms
step:169/1395 train_time:20180ms step_avg:126.92ms
step:170/1395 train_time:20307ms step_avg:126.92ms
step:171/1395 train_time:20436ms step_avg:126.93ms
step:172/1395 train_time:20563ms step_avg:126.93ms
step:173/1395 train_time:20691ms step_avg:126.94ms
step:174/1395 train_time:20821ms step_avg:126.95ms
step:175/1395 train_time:20949ms step_avg:126.97ms
step:176/1395 train_time:21078ms step_avg:126.98ms
step:177/1395 train_time:21207ms step_avg:126.99ms
step:178/1395 train_time:21336ms step_avg:127.00ms
step:179/1395 train_time:21463ms step_avg:127.00ms
step:180/1395 train_time:21591ms step_avg:127.00ms
step:181/1395 train_time:21720ms step_avg:127.02ms
step:182/1395 train_time:21849ms step_avg:127.03ms
step:183/1395 train_time:21977ms step_avg:127.04ms
step:184/1395 train_time:22104ms step_avg:127.04ms
step:185/1395 train_time:22232ms step_avg:127.04ms
step:186/1395 train_time:22360ms step_avg:127.05ms
step:187/1395 train_time:22489ms step_avg:127.06ms
step:188/1395 train_time:22618ms step_avg:127.07ms
step:189/1395 train_time:22747ms step_avg:127.08ms
step:190/1395 train_time:22875ms step_avg:127.08ms
step:191/1395 train_time:23002ms step_avg:127.09ms
step:192/1395 train_time:23132ms step_avg:127.10ms
step:193/1395 train_time:23260ms step_avg:127.10ms
step:194/1395 train_time:23387ms step_avg:127.11ms
step:195/1395 train_time:23515ms step_avg:127.11ms
step:196/1395 train_time:23643ms step_avg:127.11ms
step:197/1395 train_time:23771ms step_avg:127.12ms
step:198/1395 train_time:23900ms step_avg:127.13ms
step:199/1395 train_time:24027ms step_avg:127.13ms
step:200/1395 train_time:24155ms step_avg:127.13ms
step:201/1395 train_time:24283ms step_avg:127.14ms
step:202/1395 train_time:24411ms step_avg:127.14ms
step:203/1395 train_time:24539ms step_avg:127.15ms
step:204/1395 train_time:24668ms step_avg:127.15ms
step:205/1395 train_time:24796ms step_avg:127.16ms
step:206/1395 train_time:24924ms step_avg:127.16ms
step:207/1395 train_time:25052ms step_avg:127.17ms
step:208/1395 train_time:25182ms step_avg:127.18ms
step:209/1395 train_time:25312ms step_avg:127.20ms
step:210/1395 train_time:25443ms step_avg:127.22ms
step:211/1395 train_time:25573ms step_avg:127.23ms
step:212/1395 train_time:25704ms step_avg:127.25ms
step:213/1395 train_time:25834ms step_avg:127.26ms
step:214/1395 train_time:25965ms step_avg:127.28ms
step:215/1395 train_time:26097ms step_avg:127.30ms
step:216/1395 train_time:26229ms step_avg:127.32ms
step:217/1395 train_time:26360ms step_avg:127.34ms
step:218/1395 train_time:26490ms step_avg:127.35ms
step:219/1395 train_time:26620ms step_avg:127.37ms
step:220/1395 train_time:26751ms step_avg:127.38ms
step:221/1395 train_time:26881ms step_avg:127.40ms
step:222/1395 train_time:27013ms step_avg:127.42ms
step:223/1395 train_time:27146ms step_avg:127.45ms
step:224/1395 train_time:27276ms step_avg:127.46ms
step:225/1395 train_time:27407ms step_avg:127.47ms
step:226/1395 train_time:27540ms step_avg:127.50ms
step:227/1395 train_time:27670ms step_avg:127.51ms
step:228/1395 train_time:27800ms step_avg:127.52ms
step:229/1395 train_time:27930ms step_avg:127.53ms
step:230/1395 train_time:28062ms step_avg:127.55ms
step:231/1395 train_time:28194ms step_avg:127.57ms
step:232/1395 train_time:28324ms step_avg:127.58ms
step:233/1395 train_time:28456ms step_avg:127.61ms
step:234/1395 train_time:28587ms step_avg:127.62ms
step:235/1395 train_time:28717ms step_avg:127.63ms
step:236/1395 train_time:28847ms step_avg:127.64ms
step:237/1395 train_time:28979ms step_avg:127.66ms
step:238/1395 train_time:29109ms step_avg:127.67ms
step:239/1395 train_time:29240ms step_avg:127.69ms
step:240/1395 train_time:29370ms step_avg:127.70ms
step:241/1395 train_time:29503ms step_avg:127.72ms
step:242/1395 train_time:29634ms step_avg:127.73ms
step:243/1395 train_time:29764ms step_avg:127.74ms
step:244/1395 train_time:29895ms step_avg:127.76ms
step:245/1395 train_time:30026ms step_avg:127.77ms
step:246/1395 train_time:30158ms step_avg:127.79ms
step:247/1395 train_time:30289ms step_avg:127.80ms
step:248/1395 train_time:30422ms step_avg:127.82ms
step:249/1395 train_time:30553ms step_avg:127.84ms
step:250/1395 train_time:30684ms step_avg:127.85ms
step:250/1395 val_loss:3.9468 train_time:30750ms step_avg:128.13ms
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.008707682602107525, 0.01047583855688572, 0.0217820443212986, 0.04602677375078201, 0.0432157889008522, 0.05983917415142059, 0.09203348308801651, 0.11192848533391953, 0.1522759348154068, 0.18291550874710083, 0.17257297039031982, 0.08291499316692352, 0.014976909384131432, 0.00033338964567519724, 1.0353716106692445e-06, 0.0, 0.0]
model.lm_head.weight.data.max().item()=0.5405111312866211
step:251/1395 train_time:30819ms step_avg:127.88ms | |
step:252/1395 train_time:30951ms step_avg:127.90ms | |
step:253/1395 train_time:31082ms step_avg:127.91ms | |
step:254/1395 train_time:31210ms step_avg:127.91ms | |
step:255/1395 train_time:31339ms step_avg:127.91ms | |
step:256/1395 train_time:31468ms step_avg:127.92ms | |
step:257/1395 train_time:31595ms step_avg:127.92ms | |
step:258/1395 train_time:31728ms step_avg:127.94ms | |
step:259/1395 train_time:31863ms step_avg:127.96ms | |
step:260/1395 train_time:31996ms step_avg:127.99ms | |
step:261/1395 train_time:32126ms step_avg:127.99ms | |
step:262/1395 train_time:32255ms step_avg:128.00ms | |
step:263/1395 train_time:32384ms step_avg:128.00ms | |
step:264/1395 train_time:32513ms step_avg:128.00ms | |
step:265/1395 train_time:32643ms step_avg:128.01ms | |
step:266/1395 train_time:32774ms step_avg:128.02ms | |
step:267/1395 train_time:32906ms step_avg:128.04ms | |
step:268/1395 train_time:33039ms step_avg:128.06ms | |
step:269/1395 train_time:33169ms step_avg:128.07ms | |
step:270/1395 train_time:33299ms step_avg:128.07ms | |
step:271/1395 train_time:33428ms step_avg:128.07ms | |
step:272/1395 train_time:33557ms step_avg:128.08ms | |
step:273/1395 train_time:33687ms step_avg:128.09ms | |
step:274/1395 train_time:33817ms step_avg:128.09ms | |
step:275/1395 train_time:33949ms step_avg:128.11ms | |
step:276/1395 train_time:34081ms step_avg:128.12ms | |
step:277/1395 train_time:34212ms step_avg:128.13ms | |
step:278/1395 train_time:34342ms step_avg:128.14ms | |
step:279/1395 train_time:34471ms step_avg:128.15ms | |
step:280/1395 train_time:34601ms step_avg:128.15ms | |
step:281/1395 train_time:34731ms step_avg:128.16ms | |
step:282/1395 train_time:34864ms step_avg:128.17ms | |
step:283/1395 train_time:34995ms step_avg:128.19ms | |
step:284/1395 train_time:35128ms step_avg:128.20ms | |
step:285/1395 train_time:35259ms step_avg:128.21ms | |
step:286/1395 train_time:35388ms step_avg:128.22ms | |
step:287/1395 train_time:35518ms step_avg:128.22ms | |
step:288/1395 train_time:35649ms step_avg:128.23ms | |
step:289/1395 train_time:35779ms step_avg:128.24ms | |
step:290/1395 train_time:35910ms step_avg:128.25ms | |
step:291/1395 train_time:36043ms step_avg:128.27ms | |
step:292/1395 train_time:36173ms step_avg:128.27ms | |
step:293/1395 train_time:36305ms step_avg:128.29ms | |
step:294/1395 train_time:36434ms step_avg:128.29ms | |
step:295/1395 train_time:36565ms step_avg:128.30ms | |
step:296/1395 train_time:36695ms step_avg:128.31ms | |
step:297/1395 train_time:36828ms step_avg:128.32ms | |
step:298/1395 train_time:36960ms step_avg:128.33ms | |
step:299/1395 train_time:37090ms step_avg:128.34ms | |
step:300/1395 train_time:37222ms step_avg:128.35ms | |
step:301/1395 train_time:37351ms step_avg:128.35ms | |
step:302/1395 train_time:37481ms step_avg:128.36ms | |
step:303/1395 train_time:37612ms step_avg:128.37ms | |
step:304/1395 train_time:37741ms step_avg:128.37ms | |
step:305/1395 train_time:37870ms step_avg:128.37ms | |
step:306/1395 train_time:38003ms step_avg:128.39ms | |
step:307/1395 train_time:38134ms step_avg:128.40ms | |
step:308/1395 train_time:38265ms step_avg:128.41ms | |
step:309/1395 train_time:38395ms step_avg:128.41ms | |
step:310/1395 train_time:38526ms step_avg:128.42ms | |
step:311/1395 train_time:38656ms step_avg:128.42ms | |
step:312/1395 train_time:38786ms step_avg:128.43ms | |
step:313/1395 train_time:38919ms step_avg:128.45ms | |
step:314/1395 train_time:39052ms step_avg:128.46ms | |
step:315/1395 train_time:39185ms step_avg:128.48ms | |
step:316/1395 train_time:39318ms step_avg:128.49ms | |
step:317/1395 train_time:39451ms step_avg:128.51ms | |
step:318/1395 train_time:39585ms step_avg:128.52ms | |
step:319/1395 train_time:39717ms step_avg:128.53ms | |
step:320/1395 train_time:39849ms step_avg:128.54ms | |
step:321/1395 train_time:39982ms step_avg:128.56ms | |
step:322/1395 train_time:40115ms step_avg:128.57ms | |
step:323/1395 train_time:40247ms step_avg:128.59ms | |
step:324/1395 train_time:40380ms step_avg:128.60ms | |
step:325/1395 train_time:40512ms step_avg:128.61ms | |
step:326/1395 train_time:40645ms step_avg:128.62ms | |
step:327/1395 train_time:40779ms step_avg:128.64ms | |
step:328/1395 train_time:40913ms step_avg:128.66ms | |
step:329/1395 train_time:41047ms step_avg:128.67ms | |
step:330/1395 train_time:41181ms step_avg:128.69ms | |
step:331/1395 train_time:41312ms step_avg:128.70ms | |
step:332/1395 train_time:41444ms step_avg:128.71ms | |
step:333/1395 train_time:41576ms step_avg:128.72ms | |
step:334/1395 train_time:41707ms step_avg:128.73ms | |
step:335/1395 train_time:41841ms step_avg:128.74ms | |
step:336/1395 train_time:41974ms step_avg:128.75ms | |
step:337/1395 train_time:42107ms step_avg:128.77ms | |
step:338/1395 train_time:42240ms step_avg:128.78ms | |
step:339/1395 train_time:42372ms step_avg:128.79ms | |
step:340/1395 train_time:42504ms step_avg:128.80ms | |
step:341/1395 train_time:42637ms step_avg:128.81ms | |
step:342/1395 train_time:42768ms step_avg:128.82ms | |
step:343/1395 train_time:42904ms step_avg:128.84ms | |
step:344/1395 train_time:43037ms step_avg:128.85ms | |
step:345/1395 train_time:43171ms step_avg:128.87ms | |
step:346/1395 train_time:43305ms step_avg:128.88ms | |
step:347/1395 train_time:43438ms step_avg:128.89ms | |
step:348/1395 train_time:43570ms step_avg:128.90ms | |
step:349/1395 train_time:43702ms step_avg:128.91ms | |
step:350/1395 train_time:43833ms step_avg:128.92ms | |
step:351/1395 train_time:43965ms step_avg:128.93ms | |
step:352/1395 train_time:44097ms step_avg:128.94ms | |
step:353/1395 train_time:44230ms step_avg:128.95ms | |
step:354/1395 train_time:44362ms step_avg:128.96ms | |
step:355/1395 train_time:44495ms step_avg:128.97ms | |
step:356/1395 train_time:44628ms step_avg:128.98ms | |
step:357/1395 train_time:44760ms step_avg:128.99ms | |
step:358/1395 train_time:44893ms step_avg:129.00ms | |
step:359/1395 train_time:45028ms step_avg:129.02ms | |
step:360/1395 train_time:45161ms step_avg:129.03ms | |
step:361/1395 train_time:45294ms step_avg:129.04ms | |
step:362/1395 train_time:45427ms step_avg:129.05ms | |
step:363/1395 train_time:45558ms step_avg:129.06ms | |
step:364/1395 train_time:45690ms step_avg:129.07ms | |
step:365/1395 train_time:45825ms step_avg:129.08ms | |
step:366/1395 train_time:45957ms step_avg:129.09ms | |
step:367/1395 train_time:46089ms step_avg:129.10ms | |
step:368/1395 train_time:46222ms step_avg:129.11ms | |
step:369/1395 train_time:46355ms step_avg:129.12ms | |
step:370/1395 train_time:46488ms step_avg:129.13ms | |
step:371/1395 train_time:46621ms step_avg:129.15ms | |
step:372/1395 train_time:46755ms step_avg:129.16ms | |
step:373/1395 train_time:46888ms step_avg:129.17ms | |
step:374/1395 train_time:47022ms step_avg:129.18ms | |
step:375/1395 train_time:47155ms step_avg:129.19ms | |
step:375/1395 val_loss:3.7686 train_time:47224ms step_avg:129.38ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.006948792841285467, 0.008299409411847591, 0.017356323078274727, 0.03669121488928795, 0.0345161072909832, 0.047923494130373, 0.07445346564054489, 0.09238944202661514, 0.13034279644489288, 0.16993367671966553, 0.19259724020957947, 0.1355862319469452, 0.04924633726477623, 0.003665060270577669, 5.037082883063704e-05, 2.5884290266731114e-08, 0.0] | |
model.lm_head.weight.data.max().item()=0.6887547969818115 | |
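
Note on these periodic validation blocks: `step_avg` appears to be train_time divided by (step - 10), consistent with the first ~10 steps being excluded from timing as warmup (e.g. 47224 ms / 365 ≈ 129.38 ms for the block above). The `abs_cdf_diff` line reads as a normalized histogram of |lm_head.weight|: for the 16 printed thresholds it emits 17 fractions (first bucket [0, 0.001], 15 interior buckets, and a tail above 1.0) that sum to ~1; the tiny tail values are consistent with single weights crossing a threshold (1 / (50304 * 768) ≈ 2.6e-8). A minimal sketch of such a helper, assuming this bucketing convention — the actual definition lives in the script above and may differ:

import torch

def abs_cdf_diff(weight: torch.Tensor, thresholds: list[float]) -> list[float]:
    # Fraction of |weight| entries per bucket: [0, t1], (t1, t2], ..., (tn, inf).
    mags = weight.detach().abs().flatten().float()
    cdf = [(mags <= t).float().mean().item() for t in thresholds]
    return [cdf[0]] + [hi - lo for lo, hi in zip(cdf, cdf[1:])] + [1.0 - cdf[-1]]

# e.g. abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, ..., 1.0])
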
step:376/1395 train_time:47293ms step_avg:129.22ms | |
step:377/1395 train_time:47431ms step_avg:129.24ms | |
step:378/1395 train_time:47562ms step_avg:129.25ms | |
step:379/1395 train_time:47693ms step_avg:129.25ms | |
step:380/1395 train_time:47823ms step_avg:129.25ms | |
step:381/1395 train_time:47954ms step_avg:129.26ms | |
step:382/1395 train_time:48085ms step_avg:129.26ms | |
step:383/1395 train_time:48219ms step_avg:129.27ms | |
step:384/1395 train_time:48357ms step_avg:129.30ms | |
step:385/1395 train_time:48490ms step_avg:129.31ms | |
step:386/1395 train_time:48620ms step_avg:129.31ms | |
step:387/1395 train_time:48755ms step_avg:129.32ms | |
step:388/1395 train_time:48885ms step_avg:129.32ms | |
step:389/1395 train_time:49016ms step_avg:129.33ms | |
step:390/1395 train_time:49147ms step_avg:129.34ms | |
step:391/1395 train_time:49281ms step_avg:129.35ms | |
step:392/1395 train_time:49415ms step_avg:129.36ms | |
step:393/1395 train_time:49549ms step_avg:129.37ms | |
step:394/1395 train_time:49682ms step_avg:129.38ms | |
step:395/1395 train_time:49814ms step_avg:129.39ms | |
step:396/1395 train_time:49946ms step_avg:129.39ms | |
step:397/1395 train_time:50078ms step_avg:129.40ms | |
step:398/1395 train_time:50213ms step_avg:129.41ms | |
step:399/1395 train_time:50346ms step_avg:129.42ms | |
step:400/1395 train_time:50478ms step_avg:129.43ms | |
step:401/1395 train_time:50610ms step_avg:129.44ms | |
step:402/1395 train_time:50742ms step_avg:129.44ms | |
step:403/1395 train_time:50876ms step_avg:129.45ms | |
step:404/1395 train_time:51008ms step_avg:129.46ms | |
step:405/1395 train_time:51139ms step_avg:129.47ms | |
step:406/1395 train_time:51272ms step_avg:129.47ms | |
step:407/1395 train_time:51405ms step_avg:129.48ms | |
step:408/1395 train_time:51538ms step_avg:129.49ms | |
step:409/1395 train_time:51671ms step_avg:129.50ms | |
step:410/1395 train_time:51804ms step_avg:129.51ms | |
step:411/1395 train_time:51938ms step_avg:129.52ms | |
step:412/1395 train_time:52069ms step_avg:129.53ms | |
step:413/1395 train_time:52201ms step_avg:129.53ms | |
step:414/1395 train_time:52335ms step_avg:129.54ms | |
step:415/1395 train_time:52468ms step_avg:129.55ms | |
step:416/1395 train_time:52601ms step_avg:129.56ms | |
step:417/1395 train_time:52736ms step_avg:129.57ms | |
step:418/1395 train_time:52870ms step_avg:129.58ms | |
step:419/1395 train_time:53003ms step_avg:129.59ms | |
step:420/1395 train_time:53138ms step_avg:129.61ms | |
step:421/1395 train_time:53273ms step_avg:129.62ms | |
step:422/1395 train_time:53409ms step_avg:129.63ms | |
step:423/1395 train_time:53543ms step_avg:129.64ms | |
step:424/1395 train_time:53677ms step_avg:129.66ms | |
step:425/1395 train_time:53812ms step_avg:129.67ms | |
step:426/1395 train_time:53946ms step_avg:129.68ms | |
step:427/1395 train_time:54080ms step_avg:129.69ms | |
step:428/1395 train_time:54215ms step_avg:129.70ms | |
step:429/1395 train_time:54350ms step_avg:129.71ms | |
step:430/1395 train_time:54484ms step_avg:129.72ms | |
step:431/1395 train_time:54618ms step_avg:129.73ms | |
step:432/1395 train_time:54752ms step_avg:129.74ms | |
step:433/1395 train_time:54885ms step_avg:129.75ms | |
step:434/1395 train_time:55019ms step_avg:129.76ms | |
step:435/1395 train_time:55154ms step_avg:129.77ms | |
step:436/1395 train_time:55289ms step_avg:129.79ms | |
step:437/1395 train_time:55423ms step_avg:129.80ms | |
step:438/1395 train_time:55557ms step_avg:129.81ms | |
step:439/1395 train_time:55693ms step_avg:129.82ms | |
step:440/1395 train_time:55826ms step_avg:129.83ms | |
step:441/1395 train_time:55960ms step_avg:129.84ms | |
step:442/1395 train_time:56095ms step_avg:129.85ms | |
step:443/1395 train_time:56231ms step_avg:129.86ms | |
step:444/1395 train_time:56367ms step_avg:129.88ms | |
step:445/1395 train_time:56501ms step_avg:129.89ms | |
step:446/1395 train_time:56635ms step_avg:129.90ms | |
step:447/1395 train_time:56770ms step_avg:129.91ms | |
step:448/1395 train_time:56904ms step_avg:129.92ms | |
step:449/1395 train_time:57039ms step_avg:129.93ms | |
step:450/1395 train_time:57175ms step_avg:129.94ms | |
step:451/1395 train_time:57311ms step_avg:129.96ms | |
step:452/1395 train_time:57445ms step_avg:129.97ms | |
step:453/1395 train_time:57580ms step_avg:129.98ms | |
step:454/1395 train_time:57714ms step_avg:129.99ms | |
step:455/1395 train_time:57848ms step_avg:130.00ms | |
step:456/1395 train_time:57982ms step_avg:130.00ms | |
step:457/1395 train_time:58117ms step_avg:130.02ms | |
step:458/1395 train_time:58251ms step_avg:130.03ms | |
step:459/1395 train_time:58386ms step_avg:130.03ms | |
step:460/1395 train_time:58520ms step_avg:130.05ms | |
step:461/1395 train_time:58656ms step_avg:130.06ms | |
step:462/1395 train_time:58791ms step_avg:130.07ms | |
step:463/1395 train_time:58927ms step_avg:130.08ms | |
step:464/1395 train_time:59060ms step_avg:130.09ms | |
step:465/1395 train_time:59196ms step_avg:130.10ms | |
step:466/1395 train_time:59332ms step_avg:130.11ms | |
step:467/1395 train_time:59465ms step_avg:130.12ms | |
step:468/1395 train_time:59600ms step_avg:130.13ms | |
step:469/1395 train_time:59734ms step_avg:130.14ms | |
step:470/1395 train_time:59868ms step_avg:130.15ms | |
step:471/1395 train_time:60003ms step_avg:130.16ms | |
step:472/1395 train_time:60138ms step_avg:130.17ms | |
step:473/1395 train_time:60272ms step_avg:130.18ms | |
step:474/1395 train_time:60405ms step_avg:130.18ms | |
step:475/1395 train_time:60540ms step_avg:130.19ms | |
step:476/1395 train_time:60677ms step_avg:130.21ms | |
step:477/1395 train_time:60810ms step_avg:130.21ms | |
step:478/1395 train_time:60944ms step_avg:130.22ms | |
step:479/1395 train_time:61079ms step_avg:130.23ms | |
step:480/1395 train_time:61214ms step_avg:130.24ms | |
step:481/1395 train_time:61348ms step_avg:130.25ms | |
step:482/1395 train_time:61482ms step_avg:130.26ms | |
step:483/1395 train_time:61616ms step_avg:130.27ms | |
step:484/1395 train_time:61751ms step_avg:130.28ms | |
step:485/1395 train_time:61886ms step_avg:130.29ms | |
step:486/1395 train_time:62021ms step_avg:130.30ms | |
step:487/1395 train_time:62156ms step_avg:130.31ms | |
step:488/1395 train_time:62289ms step_avg:130.31ms | |
step:489/1395 train_time:62424ms step_avg:130.32ms | |
step:490/1395 train_time:62559ms step_avg:130.33ms | |
step:491/1395 train_time:62695ms step_avg:130.34ms | |
step:492/1395 train_time:62829ms step_avg:130.35ms | |
step:493/1395 train_time:62964ms step_avg:130.36ms | |
step:494/1395 train_time:63100ms step_avg:130.37ms | |
step:495/1395 train_time:63236ms step_avg:130.38ms | |
step:496/1395 train_time:63372ms step_avg:130.40ms | |
step:497/1395 train_time:63505ms step_avg:130.40ms | |
step:498/1395 train_time:63641ms step_avg:130.41ms | |
step:499/1395 train_time:63775ms step_avg:130.42ms | |
step:500/1395 train_time:63911ms step_avg:130.43ms | |
step:500/1395 val_loss:3.6526 train_time:63979ms step_avg:130.57ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.005923283286392689, 0.007067602127790451, 0.01489938609302044, 0.03139935061335564, 0.029607875272631645, 0.04116521030664444, 0.0641791895031929, 0.08031442016363144, 0.11517877876758575, 0.15568526089191437, 0.1923529952764511, 0.16292622685432434, 0.08623387664556503, 0.012604174204170704, 0.00046123217907734215, 1.1389088285795879e-06, 0.0] | |
model.lm_head.weight.data.max().item()=0.9659891128540039 | |
step:501/1395 train_time:64047ms step_avg:130.44ms | |
step:502/1395 train_time:64183ms step_avg:130.45ms | |
step:503/1395 train_time:64317ms step_avg:130.46ms | |
step:504/1395 train_time:64450ms step_avg:130.47ms | |
step:505/1395 train_time:64582ms step_avg:130.47ms | |
step:506/1395 train_time:64715ms step_avg:130.47ms | |
step:507/1395 train_time:64850ms step_avg:130.48ms | |
step:508/1395 train_time:64987ms step_avg:130.50ms | |
step:509/1395 train_time:65123ms step_avg:130.51ms | |
step:510/1395 train_time:65259ms step_avg:130.52ms | |
step:511/1395 train_time:65393ms step_avg:130.53ms | |
step:512/1395 train_time:65527ms step_avg:130.53ms | |
step:513/1395 train_time:65659ms step_avg:130.53ms | |
step:514/1395 train_time:65793ms step_avg:130.54ms | |
step:515/1395 train_time:65928ms step_avg:130.55ms | |
step:516/1395 train_time:66063ms step_avg:130.56ms | |
step:517/1395 train_time:66200ms step_avg:130.57ms | |
step:518/1395 train_time:66338ms step_avg:130.59ms | |
step:519/1395 train_time:66473ms step_avg:130.60ms | |
step:520/1395 train_time:66609ms step_avg:130.61ms | |
step:521/1395 train_time:66743ms step_avg:130.61ms | |
step:522/1395 train_time:66879ms step_avg:130.62ms | |
step:523/1395 train_time:67016ms step_avg:130.63ms | |
step:524/1395 train_time:67151ms step_avg:130.64ms | |
step:525/1395 train_time:67286ms step_avg:130.65ms | |
step:526/1395 train_time:67422ms step_avg:130.66ms | |
step:527/1395 train_time:67558ms step_avg:130.67ms | |
step:528/1395 train_time:67692ms step_avg:130.68ms | |
step:529/1395 train_time:67828ms step_avg:130.69ms | |
step:530/1395 train_time:67964ms step_avg:130.70ms | |
step:531/1395 train_time:68101ms step_avg:130.71ms | |
step:532/1395 train_time:68238ms step_avg:130.72ms | |
step:533/1395 train_time:68373ms step_avg:130.73ms | |
step:534/1395 train_time:68508ms step_avg:130.74ms | |
step:535/1395 train_time:68643ms step_avg:130.75ms | |
step:536/1395 train_time:68780ms step_avg:130.76ms | |
step:537/1395 train_time:68915ms step_avg:130.77ms | |
step:538/1395 train_time:69052ms step_avg:130.78ms | |
step:539/1395 train_time:69189ms step_avg:130.79ms | |
step:540/1395 train_time:69325ms step_avg:130.80ms | |
step:541/1395 train_time:69461ms step_avg:130.81ms | |
step:542/1395 train_time:69597ms step_avg:130.82ms | |
step:543/1395 train_time:69733ms step_avg:130.83ms | |
step:544/1395 train_time:69870ms step_avg:130.84ms | |
step:545/1395 train_time:70007ms step_avg:130.85ms | |
step:546/1395 train_time:70142ms step_avg:130.86ms | |
step:547/1395 train_time:70278ms step_avg:130.87ms | |
step:548/1395 train_time:70414ms step_avg:130.88ms | |
step:549/1395 train_time:70551ms step_avg:130.89ms | |
step:550/1395 train_time:70686ms step_avg:130.90ms | |
step:551/1395 train_time:70821ms step_avg:130.91ms | |
step:552/1395 train_time:70958ms step_avg:130.92ms | |
step:553/1395 train_time:71093ms step_avg:130.93ms | |
step:554/1395 train_time:71229ms step_avg:130.94ms | |
step:555/1395 train_time:71366ms step_avg:130.95ms | |
step:556/1395 train_time:71502ms step_avg:130.96ms | |
step:557/1395 train_time:71638ms step_avg:130.96ms | |
step:558/1395 train_time:71773ms step_avg:130.97ms | |
step:559/1395 train_time:71909ms step_avg:130.98ms | |
step:560/1395 train_time:72045ms step_avg:130.99ms | |
step:561/1395 train_time:72180ms step_avg:131.00ms | |
step:562/1395 train_time:72316ms step_avg:131.01ms | |
step:563/1395 train_time:72452ms step_avg:131.02ms | |
step:564/1395 train_time:72588ms step_avg:131.02ms | |
step:565/1395 train_time:72723ms step_avg:131.03ms | |
step:566/1395 train_time:72858ms step_avg:131.04ms | |
step:567/1395 train_time:72994ms step_avg:131.05ms | |
step:568/1395 train_time:73130ms step_avg:131.06ms | |
step:569/1395 train_time:73265ms step_avg:131.06ms | |
step:570/1395 train_time:73400ms step_avg:131.07ms | |
step:571/1395 train_time:73538ms step_avg:131.08ms | |
step:572/1395 train_time:73675ms step_avg:131.09ms | |
step:573/1395 train_time:73811ms step_avg:131.10ms | |
step:574/1395 train_time:73946ms step_avg:131.11ms | |
step:575/1395 train_time:74083ms step_avg:131.12ms | |
step:576/1395 train_time:74220ms step_avg:131.13ms | |
step:577/1395 train_time:74358ms step_avg:131.14ms | |
step:578/1395 train_time:74493ms step_avg:131.15ms | |
step:579/1395 train_time:74629ms step_avg:131.16ms | |
step:580/1395 train_time:74765ms step_avg:131.17ms | |
step:581/1395 train_time:74903ms step_avg:131.18ms | |
step:582/1395 train_time:75040ms step_avg:131.19ms | |
step:583/1395 train_time:75176ms step_avg:131.20ms | |
step:584/1395 train_time:75312ms step_avg:131.21ms | |
step:585/1395 train_time:75448ms step_avg:131.21ms | |
step:586/1395 train_time:75583ms step_avg:131.22ms | |
step:587/1395 train_time:75718ms step_avg:131.23ms | |
step:588/1395 train_time:75852ms step_avg:131.23ms | |
step:589/1395 train_time:75988ms step_avg:131.24ms | |
step:590/1395 train_time:76123ms step_avg:131.25ms | |
step:591/1395 train_time:76260ms step_avg:131.26ms | |
step:592/1395 train_time:76396ms step_avg:131.26ms | |
step:593/1395 train_time:76531ms step_avg:131.27ms | |
step:594/1395 train_time:76667ms step_avg:131.28ms | |
step:595/1395 train_time:76803ms step_avg:131.29ms | |
step:596/1395 train_time:76940ms step_avg:131.30ms | |
step:597/1395 train_time:77076ms step_avg:131.31ms | |
step:598/1395 train_time:77212ms step_avg:131.31ms | |
step:599/1395 train_time:77346ms step_avg:131.32ms | |
step:600/1395 train_time:77482ms step_avg:131.33ms | |
step:601/1395 train_time:77618ms step_avg:131.33ms | |
step:602/1395 train_time:77754ms step_avg:131.34ms | |
step:603/1395 train_time:77890ms step_avg:131.35ms | |
step:604/1395 train_time:78026ms step_avg:131.36ms | |
step:605/1395 train_time:78163ms step_avg:131.37ms | |
step:606/1395 train_time:78300ms step_avg:131.38ms | |
step:607/1395 train_time:78437ms step_avg:131.39ms | |
step:608/1395 train_time:78573ms step_avg:131.39ms | |
step:609/1395 train_time:78709ms step_avg:131.40ms | |
step:610/1395 train_time:78844ms step_avg:131.41ms | |
step:611/1395 train_time:78981ms step_avg:131.42ms | |
step:612/1395 train_time:79118ms step_avg:131.43ms | |
step:613/1395 train_time:79254ms step_avg:131.43ms | |
step:614/1395 train_time:79389ms step_avg:131.44ms | |
step:615/1395 train_time:79526ms step_avg:131.45ms | |
step:616/1395 train_time:79661ms step_avg:131.45ms | |
step:617/1395 train_time:79798ms step_avg:131.46ms | |
step:618/1395 train_time:79934ms step_avg:131.47ms | |
step:619/1395 train_time:80069ms step_avg:131.48ms | |
step:620/1395 train_time:80204ms step_avg:131.48ms | |
step:621/1395 train_time:80342ms step_avg:131.49ms | |
step:622/1395 train_time:80477ms step_avg:131.50ms | |
step:623/1395 train_time:80615ms step_avg:131.51ms | |
step:624/1395 train_time:80752ms step_avg:131.52ms | |
step:625/1395 train_time:80889ms step_avg:131.53ms | |
step:625/1395 val_loss:3.5751 train_time:80961ms step_avg:131.64ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.005277030169963837, 0.006316051352769136, 0.013216415420174599, 0.027910694479942322, 0.026283400133252144, 0.03668111935257912, 0.05729852616786957, 0.0720311626791954, 0.10438893735408783, 0.14404450356960297, 0.1865675449371338, 0.17562399804592133, 0.11643584817647934, 0.02614072524011135, 0.001772556221112609, 1.1466740943433251e-05, 2.5884290266731114e-08] | |
model.lm_head.weight.data.max().item()=1.0818995237350464 | |
step:626/1395 train_time:81031ms step_avg:131.54ms | |
step:627/1395 train_time:81169ms step_avg:131.55ms | |
step:628/1395 train_time:81304ms step_avg:131.56ms | |
step:629/1395 train_time:81440ms step_avg:131.57ms | |
step:630/1395 train_time:81575ms step_avg:131.57ms | |
step:631/1395 train_time:81711ms step_avg:131.58ms | |
step:632/1395 train_time:81847ms step_avg:131.59ms | |
step:633/1395 train_time:81986ms step_avg:131.60ms | |
step:634/1395 train_time:82127ms step_avg:131.61ms | |
step:635/1395 train_time:82265ms step_avg:131.62ms | |
step:636/1395 train_time:82401ms step_avg:131.63ms | |
step:637/1395 train_time:82538ms step_avg:131.64ms | |
step:638/1395 train_time:82673ms step_avg:131.65ms | |
step:639/1395 train_time:82809ms step_avg:131.65ms | |
step:640/1395 train_time:82947ms step_avg:131.66ms | |
step:641/1395 train_time:83085ms step_avg:131.67ms | |
step:642/1395 train_time:83223ms step_avg:131.68ms | |
step:643/1395 train_time:83362ms step_avg:131.69ms | |
step:644/1395 train_time:83498ms step_avg:131.70ms | |
step:645/1395 train_time:83635ms step_avg:131.71ms | |
step:646/1395 train_time:83771ms step_avg:131.72ms | |
step:647/1395 train_time:83908ms step_avg:131.72ms | |
step:648/1395 train_time:84046ms step_avg:131.73ms | |
step:649/1395 train_time:84184ms step_avg:131.74ms | |
step:650/1395 train_time:84324ms step_avg:131.76ms | |
step:651/1395 train_time:84462ms step_avg:131.77ms | |
step:652/1395 train_time:84600ms step_avg:131.78ms | |
step:653/1395 train_time:84736ms step_avg:131.78ms | |
step:654/1395 train_time:84872ms step_avg:131.79ms | |
step:655/1395 train_time:85009ms step_avg:131.80ms | |
step:656/1395 train_time:85146ms step_avg:131.81ms | |
step:657/1395 train_time:85286ms step_avg:131.82ms | |
step:658/1395 train_time:85424ms step_avg:131.83ms | |
step:659/1395 train_time:85564ms step_avg:131.84ms | |
step:660/1395 train_time:85702ms step_avg:131.85ms | |
step:661/1395 train_time:85839ms step_avg:131.86ms | |
step:662/1395 train_time:85974ms step_avg:131.86ms | |
step:663/1395 train_time:86110ms step_avg:131.87ms | |
step:664/1395 train_time:86249ms step_avg:131.88ms | |
step:665/1395 train_time:86387ms step_avg:131.89ms | |
step:666/1395 train_time:86525ms step_avg:131.90ms | |
step:667/1395 train_time:86663ms step_avg:131.91ms | |
step:668/1395 train_time:86801ms step_avg:131.92ms | |
step:669/1395 train_time:86938ms step_avg:131.92ms | |
step:670/1395 train_time:87074ms step_avg:131.93ms | |
step:671/1395 train_time:87211ms step_avg:131.94ms | |
step:672/1395 train_time:87348ms step_avg:131.95ms | |
step:673/1395 train_time:87487ms step_avg:131.96ms | |
step:674/1395 train_time:87627ms step_avg:131.97ms | |
step:675/1395 train_time:87766ms step_avg:131.98ms | |
step:676/1395 train_time:87905ms step_avg:131.99ms | |
step:677/1395 train_time:88041ms step_avg:132.00ms | |
step:678/1395 train_time:88176ms step_avg:132.00ms | |
step:679/1395 train_time:88315ms step_avg:132.01ms | |
step:680/1395 train_time:88453ms step_avg:132.02ms | |
step:681/1395 train_time:88591ms step_avg:132.03ms | |
step:682/1395 train_time:88729ms step_avg:132.04ms | |
step:683/1395 train_time:88867ms step_avg:132.05ms | |
step:684/1395 train_time:89004ms step_avg:132.05ms | |
step:685/1395 train_time:89141ms step_avg:132.06ms | |
step:686/1395 train_time:89279ms step_avg:132.07ms | |
step:687/1395 train_time:89417ms step_avg:132.08ms | |
step:688/1395 train_time:89556ms step_avg:132.09ms | |
step:689/1395 train_time:89693ms step_avg:132.10ms | |
step:690/1395 train_time:89832ms step_avg:132.11ms | |
step:691/1395 train_time:89969ms step_avg:132.11ms | |
step:692/1395 train_time:90104ms step_avg:132.12ms | |
step:693/1395 train_time:90240ms step_avg:132.12ms | |
step:694/1395 train_time:90380ms step_avg:132.13ms | |
step:695/1395 train_time:90517ms step_avg:132.14ms | |
step:696/1395 train_time:90654ms step_avg:132.15ms | |
step:697/1395 train_time:90792ms step_avg:132.16ms | |
step:698/1395 train_time:90929ms step_avg:132.16ms | |
step:699/1395 train_time:91066ms step_avg:132.17ms | |
step:700/1395 train_time:91203ms step_avg:132.18ms | |
step:701/1395 train_time:91340ms step_avg:132.19ms | |
step:702/1395 train_time:91477ms step_avg:132.19ms | |
step:703/1395 train_time:91614ms step_avg:132.20ms | |
step:704/1395 train_time:91752ms step_avg:132.21ms | |
step:705/1395 train_time:91890ms step_avg:132.22ms | |
step:706/1395 train_time:92029ms step_avg:132.23ms | |
step:707/1395 train_time:92165ms step_avg:132.23ms | |
step:708/1395 train_time:92302ms step_avg:132.24ms | |
step:709/1395 train_time:92440ms step_avg:132.25ms | |
step:710/1395 train_time:92576ms step_avg:132.25ms | |
step:711/1395 train_time:92714ms step_avg:132.26ms | |
step:712/1395 train_time:92853ms step_avg:132.27ms | |
step:713/1395 train_time:92990ms step_avg:132.28ms | |
step:714/1395 train_time:93128ms step_avg:132.28ms | |
step:715/1395 train_time:93264ms step_avg:132.29ms | |
step:716/1395 train_time:93403ms step_avg:132.30ms | |
step:717/1395 train_time:93540ms step_avg:132.31ms | |
step:718/1395 train_time:93677ms step_avg:132.31ms | |
step:719/1395 train_time:93813ms step_avg:132.32ms | |
step:720/1395 train_time:93950ms step_avg:132.32ms | |
step:721/1395 train_time:94088ms step_avg:132.33ms | |
step:722/1395 train_time:94225ms step_avg:132.34ms | |
step:723/1395 train_time:94363ms step_avg:132.35ms | |
step:724/1395 train_time:94502ms step_avg:132.36ms | |
step:725/1395 train_time:94640ms step_avg:132.36ms | |
step:726/1395 train_time:94780ms step_avg:132.37ms | |
step:727/1395 train_time:94919ms step_avg:132.38ms | |
step:728/1395 train_time:95057ms step_avg:132.39ms | |
step:729/1395 train_time:95195ms step_avg:132.40ms | |
step:730/1395 train_time:95335ms step_avg:132.41ms | |
step:731/1395 train_time:95473ms step_avg:132.42ms | |
step:732/1395 train_time:95611ms step_avg:132.42ms | |
step:733/1395 train_time:95748ms step_avg:132.43ms | |
step:734/1395 train_time:95886ms step_avg:132.44ms | |
step:735/1395 train_time:96025ms step_avg:132.45ms | |
step:736/1395 train_time:96164ms step_avg:132.46ms | |
step:737/1395 train_time:96303ms step_avg:132.47ms | |
step:738/1395 train_time:96441ms step_avg:132.47ms | |
step:739/1395 train_time:96579ms step_avg:132.48ms | |
step:740/1395 train_time:96719ms step_avg:132.49ms | |
step:741/1395 train_time:96861ms step_avg:132.50ms | |
step:742/1395 train_time:97000ms step_avg:132.51ms | |
step:743/1395 train_time:97138ms step_avg:132.52ms | |
step:744/1395 train_time:97279ms step_avg:132.53ms | |
step:745/1395 train_time:97420ms step_avg:132.54ms | |
step:746/1395 train_time:97559ms step_avg:132.55ms | |
step:747/1395 train_time:97696ms step_avg:132.56ms | |
step:748/1395 train_time:97835ms step_avg:132.57ms | |
step:749/1395 train_time:97976ms step_avg:132.58ms | |
step:750/1395 train_time:98115ms step_avg:132.59ms | |
step:750/1395 val_loss:3.5226 train_time:98187ms step_avg:132.69ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.004798817913979292, 0.005726743955165148, 0.01206782553344965, 0.02532731182873249, 0.02399737760424614, 0.03350107744336128, 0.05230684578418732, 0.06599637120962143, 0.09620877355337143, 0.1344965100288391, 0.17963281273841858, 0.18106035888195038, 0.13886530697345734, 0.041544388979673386, 0.0044156527146697044, 5.381344089983031e-05, 2.5884290266731114e-08] | |
model.lm_head.weight.data.max().item()=1.2037492990493774 | |
step:751/1395 train_time:98256ms step_avg:132.60ms | |
step:752/1395 train_time:98400ms step_avg:132.61ms | |
step:753/1395 train_time:98538ms step_avg:132.62ms | |
step:754/1395 train_time:98675ms step_avg:132.63ms | |
step:755/1395 train_time:98812ms step_avg:132.63ms | |
step:756/1395 train_time:98949ms step_avg:132.64ms | |
step:757/1395 train_time:99090ms step_avg:132.65ms | |
step:758/1395 train_time:99232ms step_avg:132.66ms | |
step:759/1395 train_time:99372ms step_avg:132.67ms | |
step:760/1395 train_time:99511ms step_avg:132.68ms | |
step:761/1395 train_time:99649ms step_avg:132.69ms | |
step:762/1395 train_time:99787ms step_avg:132.70ms | |
step:763/1395 train_time:99924ms step_avg:132.70ms | |
step:764/1395 train_time:100062ms step_avg:132.71ms | |
step:765/1395 train_time:100200ms step_avg:132.71ms | |
step:766/1395 train_time:100341ms step_avg:132.73ms | |
step:767/1395 train_time:100482ms step_avg:132.74ms | |
step:768/1395 train_time:100621ms step_avg:132.75ms | |
step:769/1395 train_time:100760ms step_avg:132.75ms | |
step:770/1395 train_time:100898ms step_avg:132.76ms | |
step:771/1395 train_time:101035ms step_avg:132.77ms | |
step:772/1395 train_time:101174ms step_avg:132.77ms | |
step:773/1395 train_time:101313ms step_avg:132.78ms | |
step:774/1395 train_time:101451ms step_avg:132.79ms | |
step:775/1395 train_time:101589ms step_avg:132.80ms | |
step:776/1395 train_time:101728ms step_avg:132.80ms | |
step:777/1395 train_time:101867ms step_avg:132.81ms | |
step:778/1395 train_time:102005ms step_avg:132.82ms | |
step:779/1395 train_time:102141ms step_avg:132.82ms | |
step:780/1395 train_time:102282ms step_avg:132.83ms | |
step:781/1395 train_time:102421ms step_avg:132.84ms | |
step:782/1395 train_time:102560ms step_avg:132.85ms | |
step:783/1395 train_time:102698ms step_avg:132.86ms | |
step:784/1395 train_time:102837ms step_avg:132.86ms | |
step:785/1395 train_time:102974ms step_avg:132.87ms | |
step:786/1395 train_time:103113ms step_avg:132.88ms | |
step:787/1395 train_time:103252ms step_avg:132.89ms | |
step:788/1395 train_time:103390ms step_avg:132.89ms | |
step:789/1395 train_time:103528ms step_avg:132.90ms | |
step:790/1395 train_time:103668ms step_avg:132.91ms | |
step:791/1395 train_time:103808ms step_avg:132.92ms | |
step:792/1395 train_time:103948ms step_avg:132.93ms | |
step:793/1395 train_time:104086ms step_avg:132.93ms | |
step:794/1395 train_time:104223ms step_avg:132.94ms | |
step:795/1395 train_time:104364ms step_avg:132.95ms | |
step:796/1395 train_time:104502ms step_avg:132.95ms | |
step:797/1395 train_time:104641ms step_avg:132.96ms | |
step:798/1395 train_time:104780ms step_avg:132.97ms | |
step:799/1395 train_time:104920ms step_avg:132.98ms | |
step:800/1395 train_time:105058ms step_avg:132.98ms | |
step:801/1395 train_time:105196ms step_avg:132.99ms | |
step:802/1395 train_time:105337ms step_avg:133.00ms | |
step:803/1395 train_time:105475ms step_avg:133.01ms | |
step:804/1395 train_time:105612ms step_avg:133.01ms | |
step:805/1395 train_time:105752ms step_avg:133.02ms | |
step:806/1395 train_time:105888ms step_avg:133.03ms | |
step:807/1395 train_time:106026ms step_avg:133.03ms | |
step:808/1395 train_time:106164ms step_avg:133.04ms | |
step:809/1395 train_time:106302ms step_avg:133.04ms | |
step:810/1395 train_time:106439ms step_avg:133.05ms | |
step:811/1395 train_time:106577ms step_avg:133.06ms | |
step:812/1395 train_time:106717ms step_avg:133.06ms | |
step:813/1395 train_time:106855ms step_avg:133.07ms | |
step:814/1395 train_time:106993ms step_avg:133.08ms | |
step:815/1395 train_time:107131ms step_avg:133.08ms | |
step:816/1395 train_time:107271ms step_avg:133.09ms | |
step:817/1395 train_time:107409ms step_avg:133.10ms | |
step:818/1395 train_time:107546ms step_avg:133.10ms | |
step:819/1395 train_time:107684ms step_avg:133.11ms | |
step:820/1395 train_time:107822ms step_avg:133.11ms | |
step:821/1395 train_time:107960ms step_avg:133.12ms | |
step:822/1395 train_time:108098ms step_avg:133.13ms | |
step:823/1395 train_time:108236ms step_avg:133.13ms | |
step:824/1395 train_time:108374ms step_avg:133.14ms | |
step:825/1395 train_time:108513ms step_avg:133.14ms | |
step:826/1395 train_time:108654ms step_avg:133.15ms | |
step:827/1395 train_time:108793ms step_avg:133.16ms | |
step:828/1395 train_time:108932ms step_avg:133.17ms | |
step:829/1395 train_time:109071ms step_avg:133.18ms | |
step:830/1395 train_time:109212ms step_avg:133.19ms | |
step:831/1395 train_time:109353ms step_avg:133.19ms | |
step:832/1395 train_time:109492ms step_avg:133.20ms | |
step:833/1395 train_time:109633ms step_avg:133.21ms | |
step:834/1395 train_time:109773ms step_avg:133.22ms | |
step:835/1395 train_time:109913ms step_avg:133.23ms | |
step:836/1395 train_time:110054ms step_avg:133.24ms | |
step:837/1395 train_time:110193ms step_avg:133.24ms | |
step:838/1395 train_time:110335ms step_avg:133.26ms | |
step:839/1395 train_time:110475ms step_avg:133.26ms | |
step:840/1395 train_time:110613ms step_avg:133.27ms | |
step:841/1395 train_time:110752ms step_avg:133.28ms | |
step:842/1395 train_time:110890ms step_avg:133.28ms | |
step:843/1395 train_time:111029ms step_avg:133.29ms | |
step:844/1395 train_time:111167ms step_avg:133.29ms | |
step:845/1395 train_time:111308ms step_avg:133.30ms | |
step:846/1395 train_time:111448ms step_avg:133.31ms | |
step:847/1395 train_time:111587ms step_avg:133.32ms | |
step:848/1395 train_time:111725ms step_avg:133.32ms | |
step:849/1395 train_time:111863ms step_avg:133.33ms | |
step:850/1395 train_time:112002ms step_avg:133.34ms | |
step:851/1395 train_time:112144ms step_avg:133.35ms | |
step:852/1395 train_time:112285ms step_avg:133.36ms | |
step:853/1395 train_time:112424ms step_avg:133.36ms | |
step:854/1395 train_time:112562ms step_avg:133.37ms | |
step:855/1395 train_time:112702ms step_avg:133.38ms | |
step:856/1395 train_time:112841ms step_avg:133.38ms | |
step:857/1395 train_time:112981ms step_avg:133.39ms | |
step:858/1395 train_time:113122ms step_avg:133.40ms | |
step:859/1395 train_time:113263ms step_avg:133.41ms | |
step:860/1395 train_time:113401ms step_avg:133.41ms | |
step:861/1395 train_time:113542ms step_avg:133.42ms | |
step:862/1395 train_time:113683ms step_avg:133.43ms | |
step:863/1395 train_time:113825ms step_avg:133.44ms | |
step:864/1395 train_time:113966ms step_avg:133.45ms | |
step:865/1395 train_time:114103ms step_avg:133.45ms | |
step:866/1395 train_time:114250ms step_avg:133.47ms | |
step:867/1395 train_time:114389ms step_avg:133.48ms | |
step:868/1395 train_time:114526ms step_avg:133.48ms | |
step:869/1395 train_time:114665ms step_avg:133.49ms | |
step:870/1395 train_time:114806ms step_avg:133.50ms | |
step:871/1395 train_time:114943ms step_avg:133.50ms | |
step:872/1395 train_time:115083ms step_avg:133.51ms | |
step:873/1395 train_time:115222ms step_avg:133.51ms | |
step:874/1395 train_time:115362ms step_avg:133.52ms | |
step:875/1395 train_time:115502ms step_avg:133.53ms | |
step:875/1395 val_loss:3.4725 train_time:115575ms step_avg:133.61ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.004446480888873339, 0.005311119835823774, 0.011230391450226307, 0.023511115461587906, 0.022214597091078758, 0.031060345470905304, 0.048637643456459045, 0.0614636205136776, 0.08988146483898163, 0.1268555372953415, 0.17302775382995605, 0.18232697248458862, 0.15479038655757904, 0.05669555068016052, 0.008374577388167381, 0.0001723376044537872, 1.0353716106692445e-07] | |
model.lm_head.weight.data.max().item()=1.2647829055786133 | |
step:876/1395 train_time:115642ms step_avg:133.54ms | |
step:877/1395 train_time:115784ms step_avg:133.55ms | |
step:878/1395 train_time:115923ms step_avg:133.55ms | |
step:879/1395 train_time:116062ms step_avg:133.56ms | |
step:880/1395 train_time:116200ms step_avg:133.56ms | |
step:881/1395 train_time:116336ms step_avg:133.57ms | |
step:882/1395 train_time:116475ms step_avg:133.57ms | |
step:883/1395 train_time:116618ms step_avg:133.58ms | |
step:884/1395 train_time:116758ms step_avg:133.59ms | |
step:885/1395 train_time:116898ms step_avg:133.60ms | |
step:886/1395 train_time:117038ms step_avg:133.61ms | |
step:887/1395 train_time:117179ms step_avg:133.61ms | |
step:888/1395 train_time:117324ms step_avg:133.63ms | |
step:889/1395 train_time:117465ms step_avg:133.63ms | |
step:890/1395 train_time:117603ms step_avg:133.64ms | |
step:891/1395 train_time:117743ms step_avg:133.65ms | |
step:892/1395 train_time:117883ms step_avg:133.65ms | |
step:893/1395 train_time:118022ms step_avg:133.66ms | |
step:894/1395 train_time:118161ms step_avg:133.67ms | |
step:895/1395 train_time:118300ms step_avg:133.67ms | |
step:896/1395 train_time:118440ms step_avg:133.68ms | |
step:897/1395 train_time:118581ms step_avg:133.69ms | |
step:898/1395 train_time:118721ms step_avg:133.70ms | |
step:899/1395 train_time:118861ms step_avg:133.70ms | |
step:900/1395 train_time:119001ms step_avg:133.71ms | |
step:901/1395 train_time:119140ms step_avg:133.71ms | |
step:902/1395 train_time:119279ms step_avg:133.72ms | |
step:903/1395 train_time:119424ms step_avg:133.73ms | |
step:904/1395 train_time:119563ms step_avg:133.74ms | |
step:905/1395 train_time:119701ms step_avg:133.74ms | |
step:906/1395 train_time:119842ms step_avg:133.75ms | |
step:907/1395 train_time:119984ms step_avg:133.76ms | |
step:908/1395 train_time:120122ms step_avg:133.77ms | |
step:909/1395 train_time:120260ms step_avg:133.77ms | |
step:910/1395 train_time:120403ms step_avg:133.78ms | |
step:911/1395 train_time:120542ms step_avg:133.79ms | |
step:912/1395 train_time:120681ms step_avg:133.79ms | |
step:913/1395 train_time:120823ms step_avg:133.80ms | |
step:914/1395 train_time:120964ms step_avg:133.81ms | |
step:915/1395 train_time:121103ms step_avg:133.82ms | |
step:916/1395 train_time:121243ms step_avg:133.82ms | |
step:917/1395 train_time:121381ms step_avg:133.83ms | |
step:918/1395 train_time:121521ms step_avg:133.83ms | |
step:919/1395 train_time:121665ms step_avg:133.84ms | |
step:920/1395 train_time:121804ms step_avg:133.85ms | |
step:921/1395 train_time:121942ms step_avg:133.86ms | |
step:922/1395 train_time:122084ms step_avg:133.86ms | |
step:923/1395 train_time:122221ms step_avg:133.87ms | |
step:924/1395 train_time:122360ms step_avg:133.87ms | |
step:925/1395 train_time:122501ms step_avg:133.88ms | |
step:926/1395 train_time:122641ms step_avg:133.89ms | |
step:927/1395 train_time:122780ms step_avg:133.89ms | |
step:928/1395 train_time:122920ms step_avg:133.90ms | |
step:929/1395 train_time:123059ms step_avg:133.91ms | |
step:930/1395 train_time:123199ms step_avg:133.91ms | |
step:931/1395 train_time:123337ms step_avg:133.92ms | |
step:932/1395 train_time:123476ms step_avg:133.92ms | |
step:933/1395 train_time:123621ms step_avg:133.93ms | |
step:934/1395 train_time:123761ms step_avg:133.94ms | |
step:935/1395 train_time:123907ms step_avg:133.95ms | |
step:936/1395 train_time:124048ms step_avg:133.96ms | |
step:937/1395 train_time:124192ms step_avg:133.97ms | |
step:938/1395 train_time:124331ms step_avg:133.98ms | |
step:939/1395 train_time:124472ms step_avg:133.99ms | |
step:940/1395 train_time:124615ms step_avg:133.99ms | |
step:941/1395 train_time:124754ms step_avg:134.00ms | |
step:942/1395 train_time:124894ms step_avg:134.01ms | |
step:943/1395 train_time:125039ms step_avg:134.02ms | |
step:944/1395 train_time:125185ms step_avg:134.03ms | |
step:945/1395 train_time:125325ms step_avg:134.04ms | |
step:946/1395 train_time:125467ms step_avg:134.05ms | |
step:947/1395 train_time:125610ms step_avg:134.06ms | |
step:948/1395 train_time:125751ms step_avg:134.06ms | |
step:949/1395 train_time:125892ms step_avg:134.07ms | |
step:950/1395 train_time:126033ms step_avg:134.08ms | |
step:951/1395 train_time:126175ms step_avg:134.09ms | |
step:952/1395 train_time:126316ms step_avg:134.09ms | |
step:953/1395 train_time:126459ms step_avg:134.10ms | |
step:954/1395 train_time:126599ms step_avg:134.11ms | |
step:955/1395 train_time:126740ms step_avg:134.12ms | |
step:956/1395 train_time:126883ms step_avg:134.13ms | |
step:957/1395 train_time:127023ms step_avg:134.13ms | |
step:958/1395 train_time:127165ms step_avg:134.14ms | |
step:959/1395 train_time:127309ms step_avg:134.15ms | |
step:960/1395 train_time:127451ms step_avg:134.16ms | |
step:961/1395 train_time:127592ms step_avg:134.17ms | |
step:962/1395 train_time:127733ms step_avg:134.17ms | |
step:963/1395 train_time:127876ms step_avg:134.18ms | |
step:964/1395 train_time:128017ms step_avg:134.19ms | |
step:965/1395 train_time:128158ms step_avg:134.20ms | |
step:966/1395 train_time:128299ms step_avg:134.20ms | |
step:967/1395 train_time:128439ms step_avg:134.21ms | |
step:968/1395 train_time:128580ms step_avg:134.22ms | |
step:969/1395 train_time:128723ms step_avg:134.23ms | |
step:970/1395 train_time:128864ms step_avg:134.23ms | |
step:971/1395 train_time:129004ms step_avg:134.24ms | |
step:972/1395 train_time:129145ms step_avg:134.25ms | |
step:973/1395 train_time:129285ms step_avg:134.25ms | |
step:974/1395 train_time:129427ms step_avg:134.26ms | |
step:975/1395 train_time:129567ms step_avg:134.27ms | |
step:976/1395 train_time:129708ms step_avg:134.27ms | |
step:977/1395 train_time:129848ms step_avg:134.28ms | |
step:978/1395 train_time:129988ms step_avg:134.29ms | |
step:979/1395 train_time:130132ms step_avg:134.30ms | |
step:980/1395 train_time:130269ms step_avg:134.30ms | |
step:981/1395 train_time:130407ms step_avg:134.30ms | |
step:982/1395 train_time:130546ms step_avg:134.31ms | |
step:983/1395 train_time:130686ms step_avg:134.31ms | |
step:984/1395 train_time:130827ms step_avg:134.32ms | |
step:985/1395 train_time:130967ms step_avg:134.33ms | |
step:986/1395 train_time:131111ms step_avg:134.33ms | |
step:987/1395 train_time:131251ms step_avg:134.34ms | |
step:988/1395 train_time:131392ms step_avg:134.35ms | |
step:989/1395 train_time:131533ms step_avg:134.35ms | |
step:990/1395 train_time:131674ms step_avg:134.36ms | |
step:991/1395 train_time:131814ms step_avg:134.37ms | |
step:992/1395 train_time:131957ms step_avg:134.38ms | |
step:993/1395 train_time:132102ms step_avg:134.39ms | |
step:994/1395 train_time:132243ms step_avg:134.39ms | |
step:995/1395 train_time:132383ms step_avg:134.40ms | |
step:996/1395 train_time:132521ms step_avg:134.40ms | |
step:997/1395 train_time:132660ms step_avg:134.41ms | |
step:998/1395 train_time:132798ms step_avg:134.41ms | |
step:999/1395 train_time:132937ms step_avg:134.42ms | |
step:1000/1395 train_time:133079ms step_avg:134.42ms | |
step:1000/1395 val_loss:3.4096 train_time:133153ms step_avg:134.50ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.004248595796525478, 0.005089110229164362, 0.010715863667428493, 0.022450583055615425, 0.02121160551905632, 0.029677322134375572, 0.04659539833664894, 0.058907650411129, 0.0863812118768692, 0.12231044471263885, 0.16876192390918732, 0.1819356083869934, 0.16289879381656647, 0.06657664477825165, 0.011903071776032448, 0.00033571923268027604, 4.6591722480116005e-07] | |
model.lm_head.weight.data.max().item()=1.17340886592865 | |
step:1001/1395 train_time:133222ms step_avg:134.43ms | |
step:1002/1395 train_time:133362ms step_avg:134.44ms | |
step:1003/1395 train_time:133505ms step_avg:134.45ms | |
step:1004/1395 train_time:133644ms step_avg:134.45ms | |
step:1005/1395 train_time:133787ms step_avg:134.46ms | |
step:1006/1395 train_time:133924ms step_avg:134.46ms | |
step:1007/1395 train_time:134063ms step_avg:134.47ms | |
step:1008/1395 train_time:134207ms step_avg:134.48ms | |
step:1009/1395 train_time:134353ms step_avg:134.49ms | |
step:1010/1395 train_time:134491ms step_avg:134.49ms | |
step:1011/1395 train_time:134630ms step_avg:134.50ms | |
step:1012/1395 train_time:134768ms step_avg:134.50ms | |
step:1013/1395 train_time:134910ms step_avg:134.51ms | |
step:1014/1395 train_time:135050ms step_avg:134.51ms | |
step:1015/1395 train_time:135191ms step_avg:134.52ms | |
step:1016/1395 train_time:135334ms step_avg:134.53ms | |
step:1017/1395 train_time:135479ms step_avg:134.54ms | |
step:1018/1395 train_time:135620ms step_avg:134.54ms | |
step:1019/1395 train_time:135762ms step_avg:134.55ms | |
step:1020/1395 train_time:135903ms step_avg:134.56ms | |
step:1021/1395 train_time:136043ms step_avg:134.56ms | |
step:1022/1395 train_time:136184ms step_avg:134.57ms | |
step:1023/1395 train_time:136326ms step_avg:134.58ms | |
step:1024/1395 train_time:136466ms step_avg:134.58ms | |
step:1025/1395 train_time:136608ms step_avg:134.59ms | |
step:1026/1395 train_time:136747ms step_avg:134.59ms | |
step:1027/1395 train_time:136886ms step_avg:134.60ms | |
step:1028/1395 train_time:137029ms step_avg:134.61ms | |
step:1029/1395 train_time:137175ms step_avg:134.62ms | |
step:1030/1395 train_time:137316ms step_avg:134.62ms | |
step:1031/1395 train_time:137454ms step_avg:134.63ms | |
step:1032/1395 train_time:137593ms step_avg:134.63ms | |
step:1033/1395 train_time:137734ms step_avg:134.64ms | |
step:1034/1395 train_time:137876ms step_avg:134.64ms | |
step:1035/1395 train_time:138017ms step_avg:134.65ms | |
step:1036/1395 train_time:138156ms step_avg:134.66ms | |
step:1037/1395 train_time:138300ms step_avg:134.66ms | |
step:1038/1395 train_time:138445ms step_avg:134.67ms | |
step:1039/1395 train_time:138584ms step_avg:134.68ms | |
step:1040/1395 train_time:138725ms step_avg:134.68ms | |
step:1041/1395 train_time:138867ms step_avg:134.69ms | |
step:1042/1395 train_time:139007ms step_avg:134.70ms | |
step:1043/1395 train_time:139149ms step_avg:134.70ms | |
step:1044/1395 train_time:139293ms step_avg:134.71ms | |
step:1045/1395 train_time:139436ms step_avg:134.72ms | |
step:1046/1395 train_time:139579ms step_avg:134.73ms | |
step:1047/1395 train_time:139721ms step_avg:134.74ms | |
step:1048/1395 train_time:139863ms step_avg:134.74ms | |
step:1049/1395 train_time:140004ms step_avg:134.75ms | |
step:1050/1395 train_time:140147ms step_avg:134.76ms | |
step:1051/1395 train_time:140290ms step_avg:134.76ms | |
step:1052/1395 train_time:140431ms step_avg:134.77ms | |
step:1053/1395 train_time:140573ms step_avg:134.78ms | |
step:1054/1395 train_time:140714ms step_avg:134.78ms | |
step:1055/1395 train_time:140854ms step_avg:134.79ms | |
step:1056/1395 train_time:140995ms step_avg:134.79ms | |
step:1057/1395 train_time:141137ms step_avg:134.80ms | |
step:1058/1395 train_time:141278ms step_avg:134.81ms | |
step:1059/1395 train_time:141423ms step_avg:134.82ms | |
step:1060/1395 train_time:141566ms step_avg:134.82ms | |
step:1061/1395 train_time:141704ms step_avg:134.83ms | |
step:1062/1395 train_time:141846ms step_avg:134.83ms | |
step:1063/1395 train_time:141985ms step_avg:134.84ms | |
step:1064/1395 train_time:142128ms step_avg:134.85ms | |
step:1065/1395 train_time:142267ms step_avg:134.85ms | |
step:1066/1395 train_time:142409ms step_avg:134.86ms | |
step:1067/1395 train_time:142552ms step_avg:134.86ms | |
step:1068/1395 train_time:142692ms step_avg:134.87ms | |
step:1069/1395 train_time:142837ms step_avg:134.88ms | |
step:1070/1395 train_time:142976ms step_avg:134.88ms | |
step:1071/1395 train_time:143123ms step_avg:134.89ms | |
step:1072/1395 train_time:143266ms step_avg:134.90ms | |
step:1073/1395 train_time:143405ms step_avg:134.91ms | |
step:1074/1395 train_time:143545ms step_avg:134.91ms | |
step:1075/1395 train_time:143692ms step_avg:134.92ms | |
step:1076/1395 train_time:143834ms step_avg:134.93ms | |
step:1077/1395 train_time:143975ms step_avg:134.93ms | |
step:1078/1395 train_time:144117ms step_avg:134.94ms | |
step:1079/1395 train_time:144264ms step_avg:134.95ms | |
step:1080/1395 train_time:144406ms step_avg:134.96ms | |
step:1081/1395 train_time:144549ms step_avg:134.97ms | |
step:1082/1395 train_time:144690ms step_avg:134.97ms | |
step:1083/1395 train_time:144831ms step_avg:134.98ms | |
step:1084/1395 train_time:144977ms step_avg:134.99ms | |
step:1085/1395 train_time:145120ms step_avg:135.00ms | |
step:1086/1395 train_time:145262ms step_avg:135.00ms | |
step:1087/1395 train_time:145406ms step_avg:135.01ms | |
step:1088/1395 train_time:145549ms step_avg:135.02ms | |
step:1089/1395 train_time:145692ms step_avg:135.02ms | |
step:1090/1395 train_time:145838ms step_avg:135.03ms | |
step:1091/1395 train_time:145983ms step_avg:135.04ms | |
step:1092/1395 train_time:146123ms step_avg:135.05ms | |
step:1093/1395 train_time:146264ms step_avg:135.05ms | |
step:1094/1395 train_time:146406ms step_avg:135.06ms | |
step:1095/1395 train_time:146548ms step_avg:135.07ms | |
step:1096/1395 train_time:146690ms step_avg:135.07ms | |
step:1097/1395 train_time:146835ms step_avg:135.08ms | |
step:1098/1395 train_time:146977ms step_avg:135.09ms | |
step:1099/1395 train_time:147122ms step_avg:135.10ms | |
step:1100/1395 train_time:147262ms step_avg:135.10ms | |
step:1101/1395 train_time:147404ms step_avg:135.11ms | |
step:1102/1395 train_time:147547ms step_avg:135.12ms | |
step:1103/1395 train_time:147691ms step_avg:135.12ms | |
step:1104/1395 train_time:147833ms step_avg:135.13ms | |
step:1105/1395 train_time:147974ms step_avg:135.14ms | |
step:1106/1395 train_time:148117ms step_avg:135.14ms | |
step:1107/1395 train_time:148260ms step_avg:135.15ms | |
step:1108/1395 train_time:148406ms step_avg:135.16ms | |
step:1109/1395 train_time:148547ms step_avg:135.17ms | |
step:1110/1395 train_time:148688ms step_avg:135.17ms | |
step:1111/1395 train_time:148829ms step_avg:135.18ms | |
step:1112/1395 train_time:148969ms step_avg:135.18ms | |
step:1113/1395 train_time:149110ms step_avg:135.19ms | |
step:1114/1395 train_time:149254ms step_avg:135.19ms | |
step:1115/1395 train_time:149397ms step_avg:135.20ms | |
step:1116/1395 train_time:149542ms step_avg:135.21ms | |
step:1117/1395 train_time:149684ms step_avg:135.22ms | |
step:1118/1395 train_time:149831ms step_avg:135.23ms | |
step:1119/1395 train_time:149971ms step_avg:135.23ms | |
step:1120/1395 train_time:150113ms step_avg:135.24ms | |
step:1121/1395 train_time:150255ms step_avg:135.24ms | |
step:1122/1395 train_time:150394ms step_avg:135.25ms | |
step:1123/1395 train_time:150535ms step_avg:135.25ms | |
step:1124/1395 train_time:150676ms step_avg:135.26ms | |
step:1125/1395 train_time:150817ms step_avg:135.26ms | |
step:1125/1395 val_loss:3.3609 train_time:150892ms step_avg:135.33ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.004146896302700043, 0.004998903721570969, 0.010501385666429996, 0.021988730877637863, 0.0207810215651989, 0.029128756374120712, 0.04559807479381561, 0.057823147624731064, 0.08483366668224335, 0.12021270394325256, 0.16654902696609497, 0.18150299787521362, 0.1661374568939209, 0.07131919264793396, 0.014010156504809856, 0.0004671337956096977, 7.506444035243476e-07] | |
model.lm_head.weight.data.max().item()=1.2437613010406494 | |
step:1126/1395 train_time:150961ms step_avg:135.27ms | |
step:1127/1395 train_time:151103ms step_avg:135.28ms | |
step:1128/1395 train_time:151247ms step_avg:135.28ms | |
step:1129/1395 train_time:151391ms step_avg:135.29ms | |
step:1130/1395 train_time:151530ms step_avg:135.29ms | |
step:1131/1395 train_time:151672ms step_avg:135.30ms | |
step:1132/1395 train_time:151811ms step_avg:135.30ms | |
step:1133/1395 train_time:151956ms step_avg:135.31ms | |
step:1134/1395 train_time:152101ms step_avg:135.32ms | |
step:1135/1395 train_time:152242ms step_avg:135.33ms | |
step:1136/1395 train_time:152387ms step_avg:135.34ms | |
step:1137/1395 train_time:152528ms step_avg:135.34ms | |
step:1138/1395 train_time:152671ms step_avg:135.35ms | |
step:1139/1395 train_time:152814ms step_avg:135.35ms | |
step:1140/1395 train_time:152956ms step_avg:135.36ms | |
step:1141/1395 train_time:153100ms step_avg:135.37ms | |
step:1142/1395 train_time:153244ms step_avg:135.37ms | |
step:1143/1395 train_time:153386ms step_avg:135.38ms | |
step:1144/1395 train_time:153530ms step_avg:135.39ms | |
step:1145/1395 train_time:153670ms step_avg:135.39ms | |
step:1146/1395 train_time:153814ms step_avg:135.40ms | |
step:1147/1395 train_time:153957ms step_avg:135.41ms | |
step:1148/1395 train_time:154101ms step_avg:135.41ms | |
step:1149/1395 train_time:154246ms step_avg:135.42ms | |
step:1150/1395 train_time:154388ms step_avg:135.43ms | |
step:1151/1395 train_time:154533ms step_avg:135.44ms | |
step:1152/1395 train_time:154675ms step_avg:135.44ms | |
step:1153/1395 train_time:154820ms step_avg:135.45ms | |
step:1154/1395 train_time:154960ms step_avg:135.45ms | |
step:1155/1395 train_time:155104ms step_avg:135.46ms | |
step:1156/1395 train_time:155255ms step_avg:135.48ms | |
step:1157/1395 train_time:155398ms step_avg:135.48ms | |
step:1158/1395 train_time:155538ms step_avg:135.49ms | |
step:1159/1395 train_time:155680ms step_avg:135.49ms | |
step:1160/1395 train_time:155821ms step_avg:135.50ms | |
step:1161/1395 train_time:155962ms step_avg:135.50ms | |
step:1162/1395 train_time:156106ms step_avg:135.51ms | |
step:1163/1395 train_time:156251ms step_avg:135.52ms | |
step:1164/1395 train_time:156393ms step_avg:135.52ms | |
step:1165/1395 train_time:156535ms step_avg:135.53ms | |
step:1166/1395 train_time:156676ms step_avg:135.53ms | |
step:1167/1395 train_time:156821ms step_avg:135.54ms | |
step:1168/1395 train_time:156964ms step_avg:135.55ms | |
step:1169/1395 train_time:157107ms step_avg:135.55ms | |
step:1170/1395 train_time:157252ms step_avg:135.56ms | |
step:1171/1395 train_time:157393ms step_avg:135.57ms | |
step:1172/1395 train_time:157535ms step_avg:135.57ms | |
step:1173/1395 train_time:157678ms step_avg:135.58ms | |
step:1174/1395 train_time:157830ms step_avg:135.59ms | |
step:1175/1395 train_time:157973ms step_avg:135.60ms | |
step:1176/1395 train_time:158117ms step_avg:135.61ms | |
step:1177/1395 train_time:158268ms step_avg:135.62ms | |
step:1178/1395 train_time:158409ms step_avg:135.62ms | |
step:1179/1395 train_time:158547ms step_avg:135.63ms | |
step:1180/1395 train_time:158696ms step_avg:135.64ms | |
step:1181/1395 train_time:158840ms step_avg:135.65ms | |
step:1182/1395 train_time:158981ms step_avg:135.65ms | |
step:1183/1395 train_time:159127ms step_avg:135.66ms | |
step:1184/1395 train_time:159269ms step_avg:135.66ms | |
step:1185/1395 train_time:159413ms step_avg:135.67ms | |
step:1186/1395 train_time:159556ms step_avg:135.68ms | |
step:1187/1395 train_time:159709ms step_avg:135.69ms | |
step:1188/1395 train_time:159849ms step_avg:135.70ms | |
step:1189/1395 train_time:159995ms step_avg:135.70ms | |
step:1190/1395 train_time:160137ms step_avg:135.71ms | |
step:1191/1395 train_time:160280ms step_avg:135.72ms | |
step:1192/1395 train_time:160422ms step_avg:135.72ms | |
step:1193/1395 train_time:160565ms step_avg:135.73ms | |
step:1194/1395 train_time:160707ms step_avg:135.73ms | |
step:1195/1395 train_time:160851ms step_avg:135.74ms | |
step:1196/1395 train_time:160994ms step_avg:135.75ms | |
step:1197/1395 train_time:161137ms step_avg:135.75ms | |
step:1198/1395 train_time:161290ms step_avg:135.77ms | |
step:1199/1395 train_time:161432ms step_avg:135.77ms | |
step:1200/1395 train_time:161575ms step_avg:135.78ms | |
step:1201/1395 train_time:161718ms step_avg:135.78ms | |
step:1202/1395 train_time:161872ms step_avg:135.80ms | |
step:1203/1395 train_time:162018ms step_avg:135.81ms | |
step:1204/1395 train_time:162164ms step_avg:135.82ms | |
step:1205/1395 train_time:162308ms step_avg:135.82ms | |
step:1206/1395 train_time:162451ms step_avg:135.83ms | |
step:1207/1395 train_time:162592ms step_avg:135.83ms | |
step:1208/1395 train_time:162736ms step_avg:135.84ms | |
step:1209/1395 train_time:162881ms step_avg:135.85ms | |
step:1210/1395 train_time:163024ms step_avg:135.85ms | |
step:1211/1395 train_time:163168ms step_avg:135.86ms | |
step:1212/1395 train_time:163311ms step_avg:135.87ms | |
step:1213/1395 train_time:163454ms step_avg:135.87ms | |
step:1214/1395 train_time:163599ms step_avg:135.88ms | |
step:1215/1395 train_time:163747ms step_avg:135.89ms | |
step:1216/1395 train_time:163888ms step_avg:135.89ms | |
step:1217/1395 train_time:164036ms step_avg:135.90ms | |
step:1218/1395 train_time:164177ms step_avg:135.91ms | |
step:1219/1395 train_time:164319ms step_avg:135.91ms | |
step:1220/1395 train_time:164462ms step_avg:135.92ms | |
step:1221/1395 train_time:164603ms step_avg:135.92ms | |
step:1222/1395 train_time:164744ms step_avg:135.93ms | |
step:1223/1395 train_time:164886ms step_avg:135.93ms | |
step:1224/1395 train_time:165034ms step_avg:135.94ms | |
step:1225/1395 train_time:165177ms step_avg:135.95ms | |
step:1226/1395 train_time:165318ms step_avg:135.95ms | |
step:1227/1395 train_time:165460ms step_avg:135.96ms | |
step:1228/1395 train_time:165604ms step_avg:135.96ms | |
step:1229/1395 train_time:165746ms step_avg:135.97ms | |
step:1230/1395 train_time:165893ms step_avg:135.98ms | |
step:1231/1395 train_time:166039ms step_avg:135.99ms | |
step:1232/1395 train_time:166184ms step_avg:135.99ms | |
step:1233/1395 train_time:166324ms step_avg:136.00ms | |
step:1234/1395 train_time:166466ms step_avg:136.00ms | |
step:1235/1395 train_time:166607ms step_avg:136.01ms | |
step:1236/1395 train_time:166749ms step_avg:136.01ms | |
step:1237/1395 train_time:166892ms step_avg:136.02ms | |
step:1238/1395 train_time:167043ms step_avg:136.03ms | |
step:1239/1395 train_time:167184ms step_avg:136.03ms | |
step:1240/1395 train_time:167328ms step_avg:136.04ms | |
step:1241/1395 train_time:167472ms step_avg:136.05ms | |
step:1242/1395 train_time:167614ms step_avg:136.05ms | |
step:1243/1395 train_time:167762ms step_avg:136.06ms | |
step:1244/1395 train_time:167905ms step_avg:136.07ms | |
step:1245/1395 train_time:168047ms step_avg:136.07ms | |
step:1246/1395 train_time:168189ms step_avg:136.08ms | |
step:1247/1395 train_time:168332ms step_avg:136.08ms | |
step:1248/1395 train_time:168475ms step_avg:136.09ms | |
step:1249/1395 train_time:168616ms step_avg:136.09ms | |
step:1250/1395 train_time:168759ms step_avg:136.10ms | |
step:1250/1395 val_loss:3.3144 train_time:168837ms step_avg:136.16ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.004117154981940985, 0.004974002949893475, 0.010423370636999607, 0.021881569176912308, 0.020692160353064537, 0.02892760932445526, 0.045347362756729126, 0.05743749812245369, 0.08432206511497498, 0.11967422813177109, 0.16581445932388306, 0.1812158077955246, 0.16699260473251343, 0.07282008975744247, 0.014826107770204544, 0.0005324657540768385, 1.4495202549369424e-06] | |
model.lm_head.weight.data.max().item()=1.2480268478393555 | |
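
Across these validation snapshots the lm_head weight distribution keeps widening: the printed max |weight| grows from ~0.69 at step 375 to ~1.25 by step 1250, and the histogram mass above 0.22 from ~5% to ~26%, while the bucket above 1.0 stays near zero. Presumably these diagnostics monitor headroom for the float8_e4m3fn cast used in the custom lm_head matmul, whose fixed scale factors rely on the weights staying within an assumed magnitude range.
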
step:1251/1395 train_time:168910ms step_avg:136.11ms | |
step:1252/1395 train_time:169052ms step_avg:136.11ms | |
step:1253/1395 train_time:169194ms step_avg:136.12ms | |
step:1254/1395 train_time:169337ms step_avg:136.12ms | |
step:1255/1395 train_time:169491ms step_avg:136.14ms | |
step:1256/1395 train_time:169633ms step_avg:136.14ms | |
step:1257/1395 train_time:169778ms step_avg:136.15ms | |
step:1258/1395 train_time:169926ms step_avg:136.16ms | |
step:1259/1395 train_time:170072ms step_avg:136.17ms | |
step:1260/1395 train_time:170215ms step_avg:136.17ms | |
step:1261/1395 train_time:170358ms step_avg:136.18ms | |
step:1262/1395 train_time:170505ms step_avg:136.19ms | |
step:1263/1395 train_time:170648ms step_avg:136.19ms | |
step:1264/1395 train_time:170791ms step_avg:136.20ms | |
step:1265/1395 train_time:170933ms step_avg:136.20ms | |
step:1266/1395 train_time:171078ms step_avg:136.21ms | |
step:1267/1395 train_time:171221ms step_avg:136.21ms | |
step:1268/1395 train_time:171365ms step_avg:136.22ms | |
step:1269/1395 train_time:171512ms step_avg:136.23ms | |
step:1270/1395 train_time:171655ms step_avg:136.23ms | |
step:1271/1395 train_time:171797ms step_avg:136.24ms | |
step:1272/1395 train_time:171937ms step_avg:136.24ms | |
step:1273/1395 train_time:172080ms step_avg:136.25ms | |
step:1274/1395 train_time:172225ms step_avg:136.25ms | |
step:1275/1395 train_time:172366ms step_avg:136.26ms | |
step:1276/1395 train_time:172507ms step_avg:136.26ms | |
step:1277/1395 train_time:172652ms step_avg:136.27ms | |
step:1278/1395 train_time:172797ms step_avg:136.27ms | |
step:1279/1395 train_time:172940ms step_avg:136.28ms | |
step:1280/1395 train_time:173089ms step_avg:136.29ms | |
step:1281/1395 train_time:173232ms step_avg:136.30ms | |
step:1282/1395 train_time:173374ms step_avg:136.30ms | |
step:1283/1395 train_time:173516ms step_avg:136.30ms | |
step:1284/1395 train_time:173659ms step_avg:136.31ms | |
step:1285/1395 train_time:173801ms step_avg:136.31ms | |
step:1286/1395 train_time:173945ms step_avg:136.32ms | |
step:1287/1395 train_time:174090ms step_avg:136.33ms | |
step:1288/1395 train_time:174236ms step_avg:136.33ms | |
step:1289/1395 train_time:174387ms step_avg:136.35ms | |
step:1290/1395 train_time:174537ms step_avg:136.36ms | |
step:1291/1395 train_time:174682ms step_avg:136.36ms | |
step:1292/1395 train_time:174826ms step_avg:136.37ms | |
step:1293/1395 train_time:174973ms step_avg:136.38ms | |
step:1294/1395 train_time:175117ms step_avg:136.38ms | |
step:1295/1395 train_time:175262ms step_avg:136.39ms | |
step:1296/1395 train_time:175406ms step_avg:136.40ms | |
step:1297/1395 train_time:175552ms step_avg:136.40ms | |
step:1298/1395 train_time:175696ms step_avg:136.41ms | |
step:1299/1395 train_time:175841ms step_avg:136.42ms | |
step:1300/1395 train_time:175983ms step_avg:136.42ms | |
step:1301/1395 train_time:176125ms step_avg:136.43ms | |
step:1302/1395 train_time:176271ms step_avg:136.43ms | |
step:1303/1395 train_time:176419ms step_avg:136.44ms | |
step:1304/1395 train_time:176564ms step_avg:136.45ms | |
step:1305/1395 train_time:176708ms step_avg:136.45ms | |
step:1306/1395 train_time:176856ms step_avg:136.46ms | |
step:1307/1395 train_time:176999ms step_avg:136.47ms | |
step:1308/1395 train_time:177144ms step_avg:136.47ms | |
step:1309/1395 train_time:177288ms step_avg:136.48ms | |
step:1310/1395 train_time:177431ms step_avg:136.49ms | |
step:1311/1395 train_time:177571ms step_avg:136.49ms | |
step:1312/1395 train_time:177717ms step_avg:136.50ms | |
step:1313/1395 train_time:177861ms step_avg:136.50ms | |
step:1314/1395 train_time:178005ms step_avg:136.51ms | |
step:1315/1395 train_time:178149ms step_avg:136.51ms | |
step:1316/1395 train_time:178291ms step_avg:136.52ms | |
step:1317/1395 train_time:178436ms step_avg:136.52ms | |
step:1318/1395 train_time:178584ms step_avg:136.53ms | |
step:1319/1395 train_time:178727ms step_avg:136.54ms | |
step:1320/1395 train_time:178869ms step_avg:136.54ms | |
step:1321/1395 train_time:179013ms step_avg:136.55ms | |
step:1322/1395 train_time:179163ms step_avg:136.56ms | |
step:1323/1395 train_time:179305ms step_avg:136.56ms | |
step:1324/1395 train_time:179451ms step_avg:136.57ms | |
step:1325/1395 train_time:179596ms step_avg:136.57ms | |
step:1326/1395 train_time:179743ms step_avg:136.58ms | |
step:1327/1395 train_time:179886ms step_avg:136.59ms | |
step:1328/1395 train_time:180027ms step_avg:136.59ms | |
step:1329/1395 train_time:180181ms step_avg:136.60ms | |
step:1330/1395 train_time:180328ms step_avg:136.61ms | |
step:1331/1395 train_time:180473ms step_avg:136.62ms | |
step:1332/1395 train_time:180625ms step_avg:136.63ms | |
step:1333/1395 train_time:180771ms step_avg:136.64ms | |
step:1334/1395 train_time:180913ms step_avg:136.64ms | |
step:1335/1395 train_time:181055ms step_avg:136.65ms | |
step:1336/1395 train_time:181204ms step_avg:136.65ms | |
step:1337/1395 train_time:181348ms step_avg:136.66ms | |
step:1338/1395 train_time:181490ms step_avg:136.66ms | |
step:1339/1395 train_time:181636ms step_avg:136.67ms | |
step:1340/1395 train_time:181784ms step_avg:136.68ms | |
step:1341/1395 train_time:181925ms step_avg:136.68ms | |
step:1342/1395 train_time:182067ms step_avg:136.69ms | |
step:1343/1395 train_time:182209ms step_avg:136.69ms | |
step:1344/1395 train_time:182353ms step_avg:136.70ms | |
step:1345/1395 train_time:182494ms step_avg:136.70ms | |
step:1346/1395 train_time:182640ms step_avg:136.71ms | |
step:1347/1395 train_time:182786ms step_avg:136.71ms | |
step:1348/1395 train_time:182929ms step_avg:136.72ms | |
step:1349/1395 train_time:183072ms step_avg:136.72ms | |
step:1350/1395 train_time:183216ms step_avg:136.73ms | |
step:1351/1395 train_time:183361ms step_avg:136.73ms | |
step:1352/1395 train_time:183510ms step_avg:136.74ms | |
step:1353/1395 train_time:183657ms step_avg:136.75ms | |
step:1354/1395 train_time:183800ms step_avg:136.76ms | |
step:1355/1395 train_time:183943ms step_avg:136.76ms | |
step:1356/1395 train_time:184084ms step_avg:136.76ms | |
step:1357/1395 train_time:184231ms step_avg:136.77ms | |
step:1358/1395 train_time:184377ms step_avg:136.78ms | |
step:1359/1395 train_time:184521ms step_avg:136.78ms | |
step:1360/1395 train_time:184667ms step_avg:136.79ms | |
step:1361/1395 train_time:184815ms step_avg:136.80ms | |
step:1362/1395 train_time:184963ms step_avg:136.81ms | |
step:1363/1395 train_time:185112ms step_avg:136.82ms | |
step:1364/1395 train_time:185254ms step_avg:136.82ms | |
step:1365/1395 train_time:185395ms step_avg:136.82ms | |
step:1366/1395 train_time:185540ms step_avg:136.83ms | |
step:1367/1395 train_time:185682ms step_avg:136.83ms | |
step:1368/1395 train_time:185826ms step_avg:136.84ms | |
step:1369/1395 train_time:185979ms step_avg:136.85ms | |
step:1370/1395 train_time:186128ms step_avg:136.86ms | |
step:1371/1395 train_time:186271ms step_avg:136.86ms | |
step:1372/1395 train_time:186419ms step_avg:136.87ms | |
step:1373/1395 train_time:186562ms step_avg:136.88ms | |
step:1374/1395 train_time:186709ms step_avg:136.88ms | |
step:1375/1395 train_time:186853ms step_avg:136.89ms | |
step:1375/1395 val_loss:3.2805 train_time:186927ms step_avg:136.94ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.004116689320653677, 0.004960491321980953, 0.010427771136164665, 0.021892674267292023, 0.020712971687316895, 0.02899925783276558, 0.045392658561468124, 0.057472314685583115, 0.08440131694078445, 0.11972377449274063, 0.16585224866867065, 0.18107904493808746, 0.16679151356220245, 0.07271187007427216, 0.01491724606603384, 0.0005466762231662869, 1.5012888070486952e-06] | |
model.lm_head.weight.data.max().item()=1.2551071643829346 | |
step:1376/1395 train_time:186998ms step_avg:136.89ms | |
step:1377/1395 train_time:187143ms step_avg:136.90ms | |
step:1378/1395 train_time:187285ms step_avg:136.90ms | |
step:1379/1395 train_time:187428ms step_avg:136.91ms | |
step:1380/1395 train_time:187571ms step_avg:136.91ms | |
step:1381/1395 train_time:187721ms step_avg:136.92ms | |
step:1382/1395 train_time:187865ms step_avg:136.93ms | |
step:1383/1395 train_time:188010ms step_avg:136.93ms | |
step:1384/1395 train_time:188159ms step_avg:136.94ms | |
step:1385/1395 train_time:188299ms step_avg:136.94ms | |
step:1386/1395 train_time:188440ms step_avg:136.95ms | |
step:1387/1395 train_time:188585ms step_avg:136.95ms | |
step:1388/1395 train_time:188727ms step_avg:136.96ms | |
step:1389/1395 train_time:188870ms step_avg:136.96ms | |
step:1390/1395 train_time:189015ms step_avg:136.97ms | |
step:1391/1395 train_time:189157ms step_avg:136.97ms | |
step:1392/1395 train_time:189302ms step_avg:136.98ms | |
step:1393/1395 train_time:189444ms step_avg:136.98ms | |
step:1394/1395 train_time:189585ms step_avg:136.98ms | |
step:1395/1395 train_time:189728ms step_avg:136.99ms | |
step:1395/1395 val_loss:3.2765 train_time:189804ms step_avg:137.04ms | |
abs_cdf_diff(model.lm_head.weight.data, [0.001, 0.0022, 0.0047, 0.01, 0.015, 0.022, 0.033, 0.047, 0.068, 0.1, 0.15, 0.22, 0.33, 0.47, 0.68, 1.0])=[0.004111848771572113, 0.0049520013853907585, 0.01044457033276558, 0.021908877417445183, 0.020701531320810318, 0.02898605726659298, 0.04541284963488579, 0.05752796307206154, 0.08441578596830368, 0.1197732612490654, 0.16587500274181366, 0.181040421128273, 0.16676121950149536, 0.07264366745948792, 0.014895839616656303, 0.0005476339138112962, 1.4754045878362376e-06] | |
model.lm_head.weight.data.max().item()=1.2578781843185425 | |
peak memory allocated: 37800 MiB reserved: 37894 MiB |
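For reference: each abs_cdf_diff line above is consistent with a bucketed histogram of the absolute weight values. Sixteen thresholds yield seventeen fractions, one per bucket between consecutive thresholds plus an overflow bucket above the final threshold of 1.0 (nonzero here, since the logged max weight is about 1.25). The helper itself is presumably defined earlier in the training script and not shown in this excerpt; the following is only a minimal sketch under that bucketing assumption:

import torch

def abs_cdf_diff(w: torch.Tensor, thresholds: list[float]) -> list[float]:
    # empirical CDF of |w| evaluated at each threshold
    a = w.detach().abs().flatten().float()
    n = a.numel()
    cdf = [(a <= t).sum().item() / n for t in thresholds]
    # first-difference the CDF: fraction of entries in each bucket,
    # with a trailing overflow bucket for entries above the last threshold
    edges = [0.0, *cdf, 1.0]
    return [hi - lo for lo, hi in zip(edges[:-1], edges[1:])]

Called as abs_cdf_diff(model.lm_head.weight.data, [0.001, ..., 1.0]), this reproduces the shape of the seventeen-entry lists printed at each eval.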
pyproject.toml
[project]
name = "modded-nanogpt"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = "==3.12.*"
dependencies = [
    "numpy>=2.1.3",
    "torch",
    "pytorch-triton>=3.2.0",
    "huggingface-hub>=0.26.2",
    "tqdm>=4.67.0",
    "pip>=24.3.1",
]

[tool.uv]
environments = [
    "sys_platform == 'linux'",
]

[tool.uv.sources]
torch = [
    { url = "https://download.pytorch.org/whl/nightly/cu126/torch-2.7.0.dev20250110%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" },
]
pytorch-triton = [
    { index = "pytorch-nightly-cu126" },
]

[[tool.uv.index]]
name = "pytorch-nightly-cu126"
url = "https://download.pytorch.org/whl/nightly/cu126"
explicit = true
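A note on the pins above: torch resolves to a single nightly cu126 wheel by direct URL (CPython 3.12, manylinux), pytorch-triton resolves only from the explicit pytorch-nightly-cu126 index, and the [tool.uv] environments entry restricts resolution to Linux to match that wheel. Assuming a recent uv, the environment should materialize with:

uv sync

Because the torch URL pins an exact nightly date, the environment will not track newer nightlies unless that URL is bumped.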