3.14 minutes
Created January 15, 2025 05:16
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import uuid
from dataclasses import dataclass
from functools import lru_cache, partial
from itertools import cycle, islice
from pathlib import Path

import torch
import torch._inductor.config as config
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
# Use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import BlockMask, flex_attention

config.coordinate_descent_tuning = True
# -----------------------------------------------------------------------------
# Custom operators

@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
    @torch.compile
    def impl(x: Tensor, w: Tensor):
        assert x.is_contiguous() and w.is_contiguous()
        x_f8 = x.mul(x_s).to(torch.float8_e4m3fn)
        w_f8 = w.mul(w_s).to(torch.float8_e4m3fn)
        out = torch._scaled_mm(
            x_f8,
            w_f8.t(),
            out_dtype=torch.bfloat16,
            scale_a=x.new_tensor(1 / x_s, dtype=torch.float32),
            scale_b=x.new_tensor(1 / w_s, dtype=torch.float32),
            use_fast_accum=True,
        )
        return out, x_f8, w_f8
    return impl(x, w)

@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    assert x.ndim == w.ndim == 2
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    return x @ w.t(), x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)

@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
        x_inv_s = grad.new_tensor(1 / x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(1 / w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(1 / grad_s, dtype=torch.float32)
        grad_f8 = grad.mul(grad_s).to(torch.float8_e5m2)
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.t().contiguous().t(),
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        grad_w = torch._scaled_mm(
            x_f8.t().contiguous(),
            grad_f8.t().contiguous().t(),
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).t()
        return grad_x, grad_w
    return impl(g, x_f8, w_f8)

@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    return x_f8.to(torch.bfloat16), w_f8.to(torch.float32)

def backward(ctx, grad_out: Tensor, *_):
    x_f8, w_f8 = ctx.saved_tensors
    x_s, w_s, grad_s = ctx.scales
    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
        grad_out, x_f8, w_f8, x_s, w_s, grad_s
    )
    return grad_x, grad_w, None, None, None

def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    *_, x_s, w_s, grad_s = inputs
    _, x_f8, w_f8 = output
    ctx.save_for_backward(x_f8, w_f8)
    ctx.scales = x_s, w_s, grad_s
    ctx.set_materialize_grads(False)

mm_op.register_autograd(backward, setup_context=setup_context)

def lm_head(x: Tensor, w: Tensor) -> Tensor:
    _x = x.flatten(0, -2)
    out: Tensor = torch.ops.nanogpt.mm(_x, w, x_s=2.0, w_s=32.0, grad_s=2.0**29)[0]
    return out.reshape(*x.shape[:-1], -1)
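
# Minimal sanity check for the FP8 lm_head path above (a sketch; never called
# during training). Assumes an FP8-capable GPU, since torch._scaled_mm requires
# compute capability >= 8.9 (e.g. H100). The scaled FP8 matmul should track a
# plain fp32 matmul up to quantization error.
def _check_fp8_lm_head():
    x = torch.randn(128, 768, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(50304, 768, device="cuda", dtype=torch.bfloat16)
    out = lm_head(x, w) # forward through the custom op
    ref = x.float() @ w.float().t() # fp32 reference
    rel_err = (out.float() - ref).abs().mean() / ref.abs().mean()
    print(f"fp8 lm_head mean relative error: {rel_err.item():.4f}")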
# -----------------------------------------------------------------------------
# Muon optimizer

@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S'_{ii} ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X
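
# Quick empirical check of the docstring's claim (a sketch; not called during
# training, and assumes a CUDA device): after 5 NS steps the singular values of
# the output should lie roughly in [0.5, 1.5] rather than exactly at 1.
def _check_orthogonalization():
    G = torch.randn(256, 128, device="cuda")
    X = zeropower_via_newtonschulz5(G, steps=5)
    S = torch.linalg.svdvals(X.float())
    print(f"singular values after NS5: min={S.min().item():.3f} max={S.max().item():.3f}")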
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-Schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
      parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: "list[Tensor]" = [*params]
        assert all(isinstance(p, Tensor) for p in params)
        sizes = {p.numel() for p in params}
        def create_update_buffer(size: int):
            b = torch.empty(world_size, size, dtype=torch.bfloat16, device="cuda")
            return dict(update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
        param_groups = [
            dict(params=[p for p in params if p.numel() == size], **create_update_buffer(size)) for size in sizes]
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            nesterov = group['nesterov']
            ns_steps = group['ns_steps']
            update_buffer = group['update_buffer']
            update_buffer_views: "list[Tensor]" = group['update_buffer_views']
            # generate weight updates in distributed fashion
            params: "list[Tensor]" = group['params']
            handle = None
            params_world = None
            def update_prev():
                if params_world is None:
                    return
                assert handle is not None
                handle.wait()
                for p_world, g_world in zip(params_world, update_buffer_views):
                    p_world.add_(
                        g_world.view_as(p_world),
                        alpha=-lr * max(1, p_world.size(0) / p_world.size(1)) ** 0.5,
                    )
            for base_i in range(len(params))[::world_size]:
                if base_i + rank < len(params):
                    p = params[base_i + rank]
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf: Tensor = state['momentum_buffer']
                    buf.lerp_(g, 1 - momentum)
                    g = g.lerp_(buf, momentum) if nesterov else buf
                    g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten()
                else:
                    g = update_buffer_views[rank]
                update_prev() # async all_gather instead of sync all_reduce by @YouJiacheng
                handle = dist.all_gather_into_tensor(update_buffer, g, async_op=True)
                params_world = params[base_i : base_i + world_size]
            update_prev()
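
# Usage sketch (not called; the real setup lives in the training script below).
# Muon takes only the 2D hidden matrices; embeddings, the LM head, and {0,1}-D
# parameters go to Adam, per the warnings above. Note the constructor allocates
# its all-gather buffers on "cuda" and reads the global rank/world_size, so
# torch.distributed must be initialized first.
def _make_optimizers(model: "GPT"):
    hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim == 2]
    hidden_ids = {id(p) for p in hidden_matrix_params}
    other_params = [p for p in model.parameters() if id(p) not in hidden_ids]
    # single-lr Adam here for brevity; the script below uses per-group lrs
    return [torch.optim.Adam(other_params, lr=0.008, betas=(0.8, 0.95), fused=True),
            Muon(hidden_matrix_params, lr=0.05, momentum=0.95)]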
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

def norm(x: Tensor, size: int | None = None):
    if size is None:
        size = x.size(-1)
    return F.rms_norm(x.unflatten(-1, (-1, size)), (size,)).flatten(-2)

class CastedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)

    def reset_parameters(self) -> None:
        std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x):
        return F.linear(x, self.weight.type_as(x))

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len=65536):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0.0, 1.0, steps=dim // 4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i, j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)
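
# Property check (a sketch; not called): rotary embedding rotates each
# (x1[i], x2[i]) pair, and the half-truncated tail uses zero frequency
# (cos=1, sin=0), so per-head vector norms are preserved.
def _check_rotary():
    rotary = Rotary(dim=128, max_seq_len=256)
    x = torch.randn(1, 256, 4, 128) # (B, T, num_heads, head_dim)
    y = rotary(x)
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), rtol=1e-4, atol=1e-4)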
class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, dim)
        self.c_v = CastedLinear(dim, dim)
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(dim // num_heads) # dim // num_heads = head_dim
        self.c_proj = CastedLinear(dim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, 'Must use batch size = 1 for FlexAttention'
        q = self.c_q(x).view(B, T, self.num_heads, -1)
        k = self.c_k(x).view(B, T, self.num_heads, -1)
        v = self.c_v(x).view(B, T, self.num_heads, -1)
        if ve is None: # skip the mid-layers' token value embeddings by @YouJiacheng
            v = self.lambdas[0] * v
        else:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask)
        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = CastedLinear(dim, 4 * dim)
        self.c_proj = CastedLinear(4 * dim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, model_dim: int, num_heads: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(model_dim, num_heads) if layer_idx != 7 else None
        self.mlp = MLP(model_dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x, vi, x0, block_mask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), vi, block_mask)
        x = x + self.mlp(norm(x))
        return x

class ValueEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.embed = nn.ModuleList([nn.Embedding(num_embeddings, embedding_dim) for _ in range(3)])

    def forward(self, input_seq) -> "list[Tensor | None]":
        ve = [emb(input_seq) for emb in self.embed]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved upon @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, ve[0], ve[1], ve[2]]
        return ve
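
# Shape check for the 012 ... 012 schedule above (a sketch; not called): the
# three embedding tables serve layers 0-2 and 9-11, with None for the middle six.
def _check_value_embeds():
    vemb = ValueEmbedding(num_embeddings=1000, embedding_dim=16)
    ve = vemb(torch.randint(0, 1000, (64,)))
    assert len(ve) == 12 and ve[3:9] == [None] * 6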
# -----------------------------------------------------------------------------
# The main GPT-2 model

def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
        self.value_embeds = ValueEmbedding(vocab_size, model_dim)
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, layer_idx) for layer_idx in range(num_layers)])
        # U-net design by @brendanh0gan
        self.num_encoder_layers = num_layers // 2 # Half of the layers for encoder
        self.num_decoder_layers = num_layers - self.num_encoder_layers # Remaining for decoder
        # Add learnable skip connection weights for decoder layers
        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
        # there are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128))
        self.lm_head.weight.detach().zero_() # @Grad62304977

    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 128
        assert input_seq.ndim == 1
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        docs = (input_seq == 50256).cumsum(0)
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_mask: Tensor):
            num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
            indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        def create_doc_swc_block_mask(sliding_window_num_blocks: Tensor):
            kv_idx = block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
            q_idx = block_idx[:, None]
            causal_bm = q_idx >= kv_idx
            causal_full_bm = q_idx > kv_idx
            window_bm = q_idx - kv_idx < sliding_window_num_blocks
            window_full_bm = window_bm # block-wise sliding window by @YouJiacheng
            # document_bm = (docs_low[q_idx] <= docs_high[kv_idx]) & (docs_low[kv_idx] <= docs_high[q_idx])
            document_bm = (docs_low[:, None] <= docs_high) & (docs_low <= docs_high[:, None])
            document_full_bm = (docs_low[:, None] == docs_high) & (docs_low == docs_high[:, None])
            nonzero_bm = causal_bm & window_bm & document_bm
            full_bm = causal_full_bm & window_full_bm & document_full_bm
            kv_num_blocks, kv_indices = dense_to_ordered(nonzero_bm & ~full_bm)
            full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)
            return BlockMask.from_kv_blocks(
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )

        block_mask = create_doc_swc_block_mask(sliding_window_num_blocks)

        x = x0 = norm(self.embed(input_seq)[None]) # @Grad62304977
        ve = self.value_embeds(input_seq)
        ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
        assert len(ve_enc) == self.num_encoder_layers and len(ve_dec) == self.num_decoder_layers

        # Store outputs for U-Net skip connections
        skip_connections = []
        # Encoder pass - process only the first half of the blocks
        for i in range(self.num_encoder_layers):
            x = self.blocks[i](x, ve_enc[i], x0, block_mask)
            skip_connections.append(x)
        # Decoder pass - process the remaining blocks with weighted skip connections
        for i in range(self.num_decoder_layers):
            x = x + self.skip_weights[i] * skip_connections.pop()
            x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask)

        x = norm(x)
        logits = lm_head(x, self.lm_head.weight) if self.training else self.lm_head(x)
        logits = 30 * torch.sigmoid(logits.float() / 7.5) # @Grad62304977
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1))
        return loss
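
# Smoke test for one forward pass (a sketch; not called). Runs in eval mode so
# the plain bf16 head is used instead of the FP8 path; assumes a CUDA device,
# since the block mask is built on "cuda". With the zero-initialized head the
# loss should come out near ln(50304) ~= 10.83, matching step 0 of the log below.
def _smoke_test_gpt():
    model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768).cuda().eval()
    seq = torch.randint(0, 50257, (8 * 128,), device="cuda")
    with torch.no_grad():
        loss = model(seq, seq.roll(-1), torch.tensor(8, dtype=torch.int32, device="cuda"))
    print(f"init loss: {loss.item():.3f}")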
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _load_data_shard(file: Path):
    header = torch.from_file(f"{file}", False, 256, dtype=torch.int32) # header is 256 int32
    assert header[0] == 20240520, 'magic number mismatch in the data .bin file'
    assert header[1] == 1, 'unsupported version'
    num_tokens = int(header[2]) # number of tokens (claimed)
    with file.open('rb', buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
        assert nbytes == 2 * num_tokens, 'number of tokens read does not match header?'
    return tokens
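
# Sketch of a matching shard writer, for reference (not used by the script):
# a 256-entry int32 header (magic 20240520, version 1, token count), followed
# by the tokens as uint16.
def _write_data_shard(file: Path, tokens):
    import numpy as np # avoid top level import
    header = np.zeros(256, dtype=np.int32)
    header[0], header[1], header[2] = 20240520, 1, len(tokens)
    with file.open("wb") as f:
        f.write(header.tobytes())
        f.write(np.asarray(tokens, dtype=np.uint16).tobytes())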
def distributed_data(filename_pattern: str, batch_size: int):
    files = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    file_iter = cycle(files)
    tokens, pos = _load_data_shard(next(file_iter)), 0
    while True:
        if pos + batch_size + 1 >= len(tokens):
            tokens, pos = _load_data_shard(next(file_iter)), 0
        buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
        pos += batch_size
        yield inputs, targets
# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    # data
    train_bin = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
    val_bin = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
    val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    # optimization
    batch_size = 8*64*1024 # batch size in tokens
    num_iterations = 1395 # number of iterations to run
    cooldown_frac = 0.4 # fraction of training spent in the linear cooldown of the trapezoidal LR schedule
    # evaluation and logging
    val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end
    # implementation
    seq_len = 64*1024 # FlexAttention sequence length
    save_checkpoint = False
args = Hyperparameters()

# torchrun sets these env variables
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
assert torch.cuda.is_available()
device = torch.device('cuda', int(os.environ['LOCAL_RANK']))
torch.cuda.set_device(device)
dist.init_process_group(backend='nccl', device_id=device)
dist.barrier()
master_process = (rank == 0) # this process will do logging, checkpointing etc.

# begin logging
def print0(s, console=False): ... # no-op on non-master ranks
if master_process:
    run_id = uuid.uuid4()
    (logs_dir := Path("logs")).mkdir(exist_ok=True)
    logfile = logs_dir / f"{run_id}.txt"
    print(logfile.stem)
    def print0(s, console=False):
        with logfile.open("a") as f:
            if console:
                print(s)
            print(s, file=f)

# begin by printing this file (the Python code)
print0(code)
print0('='*100)
# log information about the hardware/software environment this is running on
print0(f'Running Python {sys.version}')
print0(f'Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}')
def nvidia_smi():
    import subprocess # avoid top level import
    return subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0('='*100)
# load data
train_loader = distributed_data(args.train_bin, args.batch_size)
val_loader = partial(distributed_data, args.val_bin)

model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768).cuda()
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

# collect the parameters to optimize
hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim == 2]
embed_params = [model.embed.weight, *model.value_embeds.parameters()]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]

# init the optimizer(s)
adam_params = [dict(params=head_params, lr=0.008), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), fused=True)
optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95)
optimizers = [optimizer1, optimizer2]

# learning rate schedule: stable then decay (multiplier 1.0 until the cooldown, then linear down to 0.1)
def get_lr(it: int):
    t = 1 - it / args.num_iterations # time remaining in training
    assert 1 >= t >= 0
    w = min(t / args.cooldown_frac, 1.0) # 1 -> 0
    return w * 1.0 + (1 - w) * 0.1
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

@lru_cache(1)
def sw_num_blks(window_size: int):
    return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
model: nn.Module = torch.compile(model)
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
train_steps = args.num_iterations
for step in range(train_steps + 1):
    last_step = (step == train_steps)
    # This effectively ignores the timing of the first 10 steps, which are slower for weird reasons.
    # Alternatively, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.perf_counter()
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
    # Linearly increase the block-wise sliding window size over training, 128 -> 1792:
    # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
    window_size = next_multiple_of_n(1728 * step / train_steps, n=128)

    # --------------- VALIDATION SECTION -----------------
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        model.eval()
        val_bs = world_size * args.seq_len
        assert args.val_tokens % val_bs == 0
        val_steps = args.val_tokens // val_bs
        with torch.no_grad():
            val_loss = sum(model(x, y, sw_num_blks(window_size)) for x, y in islice(val_loader(val_bs), val_steps)) / val_steps
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        print0(f'step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms', console=True)
        model.train()
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    if last_step:
        if master_process and args.save_checkpoint:
            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f'logs/{run_id}', exist_ok=True)
            torch.save(log, f'logs/{run_id}/state_step{step:06d}.pt')
        # the last step only has the validation loop, so break to avoid training
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    inputs, targets = next(train_loader)
    for input_seq, target_seq in zip(inputs.split(args.seq_len), targets.split(args.seq_len)):
        model(input_seq, target_seq, sw_num_blks(window_size)).backward()
    for param in model.parameters():
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
    # momentum warmup for Muon
    frac = min(step / 300, 1)
    for group in optimizer2.param_groups:
        group['momentum'] = (1 - frac) * 0.85 + frac * 0.95
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------

    # everything that follows now is just diagnostics, prints, logging, etc.
    approx_time = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f'step:{step+1}/{train_steps} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms', console=True)

print0(
    f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
    f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
)
dist.destroy_process_group()
====================================================================================================
Running Python 3.12.8 (main, Dec 19 2024, 14:33:20) [Clang 18.1.8 ]
Running PyTorch 2.7.0.dev20250110+cu126 compiled for CUDA 12.6
Wed Jan 15 04:26:12 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.6     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:65:02.0 Off |                    0 |
| N/A  37C    P0             120W / 700W  |   7092MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:67:02.0 Off |                    0 |
| N/A  45C    P0             129W / 700W  |   3459MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:69:02.0 Off |                    0 |
| N/A  45C    P0             122W / 700W  |   3459MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:6B:02.0 Off |                    0 |
| N/A  39C    P0             118W / 700W  |   3459MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:6F:02.0 Off |                    0 |
| N/A  38C    P0             117W / 700W  |   3459MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:71:02.0 Off |                    0 |
| N/A  45C    P0             121W / 700W  |   3459MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:73:02.0 Off |                    0 |
| N/A  45C    P0             127W / 700W  |   3459MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:75:02.0 Off |                    0 |
| N/A  38C    P0             124W / 700W  |   3219MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
====================================================================================================
step:0/1395 val_loss:10.8258 train_time:0ms step_avg:nanms
step:1/1395 train_time:18453ms step_avg:nanms
step:2/1395 train_time:18483ms step_avg:nanms
step:3/1395 train_time:18595ms step_avg:nanms
step:4/1395 train_time:18718ms step_avg:nanms
step:5/1395 train_time:18841ms step_avg:nanms
step:6/1395 train_time:18964ms step_avg:nanms
step:7/1395 train_time:19088ms step_avg:nanms
step:8/1395 train_time:19212ms step_avg:nanms
step:9/1395 train_time:19336ms step_avg:nanms
step:10/1395 train_time:19460ms step_avg:nanms
step:11/1395 train_time:125ms step_avg:nanms
step:12/1395 train_time:249ms step_avg:nanms
step:13/1395 train_time:374ms step_avg:124.73ms
[steps 14-124 omitted: step_avg holds near 124.5-125.3ms]
step:125/1395 val_loss:4.3737 train_time:14509ms step_avg:126.16ms
[steps 126-249 omitted: step_avg drifts from 125.31ms to 126.58ms]
step:250/1395 val_loss:3.9533 train_time:30487ms step_avg:127.03ms
[steps 251-374 omitted: step_avg drifts from 126.61ms to 127.89ms]
step:375/1395 val_loss:3.7715 train_time:46786ms step_avg:128.18ms
[steps 376-499 omitted: step_avg drifts from 127.91ms to 129.07ms]
step:500/1395 val_loss:3.6559 train_time:63357ms step_avg:129.30ms
[steps 501-549 omitted: step_avg drifts from 129.09ms to 129.55ms; the captured log ends at step 549]
step:550/1395 train_time:69962ms step_avg:129.56ms | |
step:551/1395 train_time:70095ms step_avg:129.57ms | |
step:552/1395 train_time:70229ms step_avg:129.57ms | |
step:553/1395 train_time:70363ms step_avg:129.58ms | |
step:554/1395 train_time:70500ms step_avg:129.60ms | |
step:555/1395 train_time:70637ms step_avg:129.61ms | |
step:556/1395 train_time:70770ms step_avg:129.62ms | |
step:557/1395 train_time:70905ms step_avg:129.63ms | |
step:558/1395 train_time:71040ms step_avg:129.63ms | |
step:559/1395 train_time:71173ms step_avg:129.64ms | |
step:560/1395 train_time:71307ms step_avg:129.65ms | |
step:561/1395 train_time:71441ms step_avg:129.66ms | |
step:562/1395 train_time:71575ms step_avg:129.66ms | |
step:563/1395 train_time:71708ms step_avg:129.67ms | |
step:564/1395 train_time:71844ms step_avg:129.68ms | |
step:565/1395 train_time:71977ms step_avg:129.69ms | |
step:566/1395 train_time:72111ms step_avg:129.70ms | |
step:567/1395 train_time:72245ms step_avg:129.70ms | |
step:568/1395 train_time:72379ms step_avg:129.71ms | |
step:569/1395 train_time:72512ms step_avg:129.72ms | |
step:570/1395 train_time:72646ms step_avg:129.73ms | |
step:571/1395 train_time:72781ms step_avg:129.73ms | |
step:572/1395 train_time:72916ms step_avg:129.74ms | |
step:573/1395 train_time:73050ms step_avg:129.75ms | |
step:574/1395 train_time:73186ms step_avg:129.76ms | |
step:575/1395 train_time:73321ms step_avg:129.77ms | |
step:576/1395 train_time:73456ms step_avg:129.78ms | |
step:577/1395 train_time:73590ms step_avg:129.79ms | |
step:578/1395 train_time:73725ms step_avg:129.80ms | |
step:579/1395 train_time:73859ms step_avg:129.81ms | |
step:580/1395 train_time:73993ms step_avg:129.81ms | |
step:581/1395 train_time:74128ms step_avg:129.82ms | |
step:582/1395 train_time:74264ms step_avg:129.83ms | |
step:583/1395 train_time:74399ms step_avg:129.84ms | |
step:584/1395 train_time:74533ms step_avg:129.85ms | |
step:585/1395 train_time:74668ms step_avg:129.86ms | |
step:586/1395 train_time:74802ms step_avg:129.86ms | |
step:587/1395 train_time:74936ms step_avg:129.87ms | |
step:588/1395 train_time:75069ms step_avg:129.88ms | |
step:589/1395 train_time:75205ms step_avg:129.89ms | |
step:590/1395 train_time:75340ms step_avg:129.90ms | |
step:591/1395 train_time:75473ms step_avg:129.90ms | |
step:592/1395 train_time:75608ms step_avg:129.91ms | |
step:593/1395 train_time:75743ms step_avg:129.92ms | |
step:594/1395 train_time:75877ms step_avg:129.93ms | |
step:595/1395 train_time:76010ms step_avg:129.93ms | |
step:596/1395 train_time:76146ms step_avg:129.94ms | |
step:597/1395 train_time:76282ms step_avg:129.95ms | |
step:598/1395 train_time:76417ms step_avg:129.96ms | |
step:599/1395 train_time:76550ms step_avg:129.97ms | |
step:600/1395 train_time:76684ms step_avg:129.97ms | |
step:601/1395 train_time:76818ms step_avg:129.98ms | |
step:602/1395 train_time:76950ms step_avg:129.98ms | |
step:603/1395 train_time:77086ms step_avg:129.99ms | |
step:604/1395 train_time:77221ms step_avg:130.00ms | |
step:605/1395 train_time:77356ms step_avg:130.01ms | |
step:606/1395 train_time:77491ms step_avg:130.02ms | |
step:607/1395 train_time:77629ms step_avg:130.03ms | |
step:608/1395 train_time:77764ms step_avg:130.04ms | |
step:609/1395 train_time:77900ms step_avg:130.05ms | |
step:610/1395 train_time:78035ms step_avg:130.06ms | |
step:611/1395 train_time:78169ms step_avg:130.06ms | |
step:612/1395 train_time:78305ms step_avg:130.07ms | |
step:613/1395 train_time:78440ms step_avg:130.08ms | |
step:614/1395 train_time:78574ms step_avg:130.09ms | |
step:615/1395 train_time:78708ms step_avg:130.10ms | |
step:616/1395 train_time:78844ms step_avg:130.11ms | |
step:617/1395 train_time:78978ms step_avg:130.11ms | |
step:618/1395 train_time:79111ms step_avg:130.12ms | |
step:619/1395 train_time:79247ms step_avg:130.13ms | |
step:620/1395 train_time:79383ms step_avg:130.14ms | |
step:621/1395 train_time:79518ms step_avg:130.14ms | |
step:622/1395 train_time:79651ms step_avg:130.15ms | |
step:623/1395 train_time:79788ms step_avg:130.16ms | |
step:624/1395 train_time:79923ms step_avg:130.17ms | |
step:625/1395 train_time:80059ms step_avg:130.18ms | |
step:625/1395 val_loss:3.5768 train_time:80168ms step_avg:130.35ms | |
step:626/1395 train_time:80197ms step_avg:130.19ms | |
step:627/1395 train_time:80338ms step_avg:130.21ms | |
step:628/1395 train_time:80474ms step_avg:130.22ms | |
step:629/1395 train_time:80609ms step_avg:130.22ms | |
step:630/1395 train_time:80745ms step_avg:130.23ms | |
step:631/1395 train_time:80880ms step_avg:130.24ms | |
step:632/1395 train_time:81014ms step_avg:130.25ms | |
step:633/1395 train_time:81149ms step_avg:130.26ms | |
step:634/1395 train_time:81285ms step_avg:130.26ms | |
step:635/1395 train_time:81422ms step_avg:130.27ms | |
step:636/1395 train_time:81559ms step_avg:130.29ms | |
step:637/1395 train_time:81695ms step_avg:130.30ms | |
step:638/1395 train_time:81830ms step_avg:130.30ms | |
step:639/1395 train_time:81965ms step_avg:130.31ms | |
step:640/1395 train_time:82100ms step_avg:130.32ms | |
step:641/1395 train_time:82236ms step_avg:130.33ms | |
step:642/1395 train_time:82373ms step_avg:130.34ms | |
step:643/1395 train_time:82511ms step_avg:130.35ms | |
step:644/1395 train_time:82646ms step_avg:130.36ms | |
step:645/1395 train_time:82781ms step_avg:130.36ms | |
step:646/1395 train_time:82917ms step_avg:130.37ms | |
step:647/1395 train_time:83055ms step_avg:130.38ms | |
step:648/1395 train_time:83191ms step_avg:130.39ms | |
step:649/1395 train_time:83327ms step_avg:130.40ms | |
step:650/1395 train_time:83463ms step_avg:130.41ms | |
step:651/1395 train_time:83599ms step_avg:130.42ms | |
step:652/1395 train_time:83735ms step_avg:130.43ms | |
step:653/1395 train_time:83871ms step_avg:130.44ms | |
step:654/1395 train_time:84006ms step_avg:130.44ms | |
step:655/1395 train_time:84140ms step_avg:130.45ms | |
step:656/1395 train_time:84277ms step_avg:130.46ms | |
step:657/1395 train_time:84414ms step_avg:130.47ms | |
step:658/1395 train_time:84549ms step_avg:130.48ms | |
step:659/1395 train_time:84686ms step_avg:130.49ms | |
step:660/1395 train_time:84822ms step_avg:130.50ms | |
step:661/1395 train_time:84959ms step_avg:130.50ms | |
step:662/1395 train_time:85092ms step_avg:130.51ms | |
step:663/1395 train_time:85228ms step_avg:130.52ms | |
step:664/1395 train_time:85364ms step_avg:130.53ms | |
step:665/1395 train_time:85501ms step_avg:130.54ms | |
step:666/1395 train_time:85637ms step_avg:130.54ms | |
step:667/1395 train_time:85774ms step_avg:130.55ms | |
step:668/1395 train_time:85912ms step_avg:130.56ms | |
step:669/1395 train_time:86048ms step_avg:130.57ms | |
step:670/1395 train_time:86183ms step_avg:130.58ms | |
step:671/1395 train_time:86318ms step_avg:130.59ms | |
step:672/1395 train_time:86455ms step_avg:130.60ms | |
step:673/1395 train_time:86592ms step_avg:130.61ms | |
step:674/1395 train_time:86728ms step_avg:130.61ms | |
step:675/1395 train_time:86865ms step_avg:130.62ms | |
step:676/1395 train_time:87001ms step_avg:130.63ms | |
step:677/1395 train_time:87136ms step_avg:130.64ms | |
step:678/1395 train_time:87272ms step_avg:130.65ms | |
step:679/1395 train_time:87408ms step_avg:130.65ms | |
step:680/1395 train_time:87545ms step_avg:130.66ms | |
step:681/1395 train_time:87680ms step_avg:130.67ms | |
step:682/1395 train_time:87817ms step_avg:130.68ms | |
step:683/1395 train_time:87956ms step_avg:130.69ms | |
step:684/1395 train_time:88091ms step_avg:130.70ms | |
step:685/1395 train_time:88227ms step_avg:130.71ms | |
step:686/1395 train_time:88363ms step_avg:130.71ms | |
step:687/1395 train_time:88500ms step_avg:130.72ms | |
step:688/1395 train_time:88636ms step_avg:130.73ms | |
step:689/1395 train_time:88772ms step_avg:130.74ms | |
step:690/1395 train_time:88908ms step_avg:130.75ms | |
step:691/1395 train_time:89044ms step_avg:130.75ms | |
step:692/1395 train_time:89179ms step_avg:130.76ms | |
step:693/1395 train_time:89315ms step_avg:130.77ms | |
step:694/1395 train_time:89452ms step_avg:130.78ms | |
step:695/1395 train_time:89588ms step_avg:130.79ms | |
step:696/1395 train_time:89723ms step_avg:130.79ms | |
step:697/1395 train_time:89861ms step_avg:130.80ms | |
step:698/1395 train_time:89997ms step_avg:130.81ms | |
step:699/1395 train_time:90134ms step_avg:130.82ms | |
step:700/1395 train_time:90271ms step_avg:130.83ms | |
step:701/1395 train_time:90407ms step_avg:130.83ms | |
step:702/1395 train_time:90542ms step_avg:130.84ms | |
step:703/1395 train_time:90677ms step_avg:130.85ms | |
step:704/1395 train_time:90814ms step_avg:130.86ms | |
step:705/1395 train_time:90952ms step_avg:130.87ms | |
step:706/1395 train_time:91091ms step_avg:130.88ms | |
step:707/1395 train_time:91226ms step_avg:130.88ms | |
step:708/1395 train_time:91361ms step_avg:130.89ms | |
step:709/1395 train_time:91497ms step_avg:130.90ms | |
step:710/1395 train_time:91634ms step_avg:130.91ms | |
step:711/1395 train_time:91771ms step_avg:130.91ms | |
step:712/1395 train_time:91908ms step_avg:130.92ms | |
step:713/1395 train_time:92044ms step_avg:130.93ms | |
step:714/1395 train_time:92179ms step_avg:130.94ms | |
step:715/1395 train_time:92315ms step_avg:130.94ms | |
step:716/1395 train_time:92452ms step_avg:130.95ms | |
step:717/1395 train_time:92587ms step_avg:130.96ms | |
step:718/1395 train_time:92722ms step_avg:130.96ms | |
step:719/1395 train_time:92857ms step_avg:130.97ms | |
step:720/1395 train_time:92994ms step_avg:130.98ms | |
step:721/1395 train_time:93130ms step_avg:130.98ms | |
step:722/1395 train_time:93265ms step_avg:130.99ms | |
step:723/1395 train_time:93400ms step_avg:131.00ms | |
step:724/1395 train_time:93537ms step_avg:131.00ms | |
step:725/1395 train_time:93675ms step_avg:131.01ms | |
step:726/1395 train_time:93815ms step_avg:131.03ms | |
step:727/1395 train_time:93954ms step_avg:131.04ms | |
step:728/1395 train_time:94090ms step_avg:131.04ms | |
step:729/1395 train_time:94226ms step_avg:131.05ms | |
step:730/1395 train_time:94363ms step_avg:131.06ms | |
step:731/1395 train_time:94499ms step_avg:131.07ms | |
step:732/1395 train_time:94635ms step_avg:131.07ms | |
step:733/1395 train_time:94773ms step_avg:131.08ms | |
step:734/1395 train_time:94911ms step_avg:131.09ms | |
step:735/1395 train_time:95047ms step_avg:131.10ms | |
step:736/1395 train_time:95183ms step_avg:131.11ms | |
step:737/1395 train_time:95319ms step_avg:131.11ms | |
step:738/1395 train_time:95458ms step_avg:131.12ms | |
step:739/1395 train_time:95595ms step_avg:131.13ms | |
step:740/1395 train_time:95733ms step_avg:131.14ms | |
step:741/1395 train_time:95872ms step_avg:131.15ms | |
step:742/1395 train_time:96009ms step_avg:131.16ms | |
step:743/1395 train_time:96147ms step_avg:131.17ms | |
step:744/1395 train_time:96286ms step_avg:131.18ms | |
step:745/1395 train_time:96425ms step_avg:131.19ms | |
step:746/1395 train_time:96561ms step_avg:131.20ms | |
step:747/1395 train_time:96698ms step_avg:131.20ms | |
step:748/1395 train_time:96835ms step_avg:131.21ms | |
step:749/1395 train_time:96976ms step_avg:131.23ms | |
step:750/1395 train_time:97112ms step_avg:131.23ms | |
step:750/1395 val_loss:3.5221 train_time:97225ms step_avg:131.38ms | |
step:751/1395 train_time:97256ms step_avg:131.25ms | |
step:752/1395 train_time:97398ms step_avg:131.26ms | |
step:753/1395 train_time:97537ms step_avg:131.27ms | |
step:754/1395 train_time:97673ms step_avg:131.28ms | |
step:755/1395 train_time:97809ms step_avg:131.29ms | |
step:756/1395 train_time:97945ms step_avg:131.29ms | |
step:757/1395 train_time:98086ms step_avg:131.31ms | |
step:758/1395 train_time:98222ms step_avg:131.31ms | |
step:759/1395 train_time:98361ms step_avg:131.32ms | |
step:760/1395 train_time:98500ms step_avg:131.33ms | |
step:761/1395 train_time:98638ms step_avg:131.34ms | |
step:762/1395 train_time:98777ms step_avg:131.35ms | |
step:763/1395 train_time:98913ms step_avg:131.36ms | |
step:764/1395 train_time:99051ms step_avg:131.37ms | |
step:765/1395 train_time:99186ms step_avg:131.37ms | |
step:766/1395 train_time:99324ms step_avg:131.38ms | |
step:767/1395 train_time:99463ms step_avg:131.39ms | |
step:768/1395 train_time:99601ms step_avg:131.40ms | |
step:769/1395 train_time:99738ms step_avg:131.41ms | |
step:770/1395 train_time:99875ms step_avg:131.42ms | |
step:771/1395 train_time:100012ms step_avg:131.42ms | |
step:772/1395 train_time:100147ms step_avg:131.43ms | |
step:773/1395 train_time:100285ms step_avg:131.43ms | |
step:774/1395 train_time:100422ms step_avg:131.44ms | |
step:775/1395 train_time:100560ms step_avg:131.45ms | |
step:776/1395 train_time:100698ms step_avg:131.46ms | |
step:777/1395 train_time:100836ms step_avg:131.47ms | |
step:778/1395 train_time:100973ms step_avg:131.48ms | |
step:779/1395 train_time:101108ms step_avg:131.48ms | |
step:780/1395 train_time:101246ms step_avg:131.49ms | |
step:781/1395 train_time:101383ms step_avg:131.50ms | |
step:782/1395 train_time:101521ms step_avg:131.50ms | |
step:783/1395 train_time:101658ms step_avg:131.51ms | |
step:784/1395 train_time:101795ms step_avg:131.52ms | |
step:785/1395 train_time:101931ms step_avg:131.52ms | |
step:786/1395 train_time:102070ms step_avg:131.53ms | |
step:787/1395 train_time:102207ms step_avg:131.54ms | |
step:788/1395 train_time:102343ms step_avg:131.55ms | |
step:789/1395 train_time:102481ms step_avg:131.55ms | |
step:790/1395 train_time:102619ms step_avg:131.56ms | |
step:791/1395 train_time:102756ms step_avg:131.57ms | |
step:792/1395 train_time:102895ms step_avg:131.58ms | |
step:793/1395 train_time:103031ms step_avg:131.58ms | |
step:794/1395 train_time:103168ms step_avg:131.59ms | |
step:795/1395 train_time:103308ms step_avg:131.60ms | |
step:796/1395 train_time:103445ms step_avg:131.61ms | |
step:797/1395 train_time:103582ms step_avg:131.62ms | |
step:798/1395 train_time:103719ms step_avg:131.62ms | |
step:799/1395 train_time:103858ms step_avg:131.63ms | |
step:800/1395 train_time:103996ms step_avg:131.64ms | |
step:801/1395 train_time:104134ms step_avg:131.65ms | |
step:802/1395 train_time:104275ms step_avg:131.66ms | |
step:803/1395 train_time:104412ms step_avg:131.67ms | |
step:804/1395 train_time:104548ms step_avg:131.67ms | |
step:805/1395 train_time:104686ms step_avg:131.68ms | |
step:806/1395 train_time:104822ms step_avg:131.69ms | |
step:807/1395 train_time:104959ms step_avg:131.69ms | |
step:808/1395 train_time:105096ms step_avg:131.70ms | |
step:809/1395 train_time:105234ms step_avg:131.71ms | |
step:810/1395 train_time:105369ms step_avg:131.71ms | |
step:811/1395 train_time:105505ms step_avg:131.72ms | |
step:812/1395 train_time:105643ms step_avg:131.72ms | |
step:813/1395 train_time:105779ms step_avg:131.73ms | |
step:814/1395 train_time:105916ms step_avg:131.74ms | |
step:815/1395 train_time:106052ms step_avg:131.74ms | |
step:816/1395 train_time:106190ms step_avg:131.75ms | |
step:817/1395 train_time:106327ms step_avg:131.76ms | |
step:818/1395 train_time:106462ms step_avg:131.76ms | |
step:819/1395 train_time:106600ms step_avg:131.77ms | |
step:820/1395 train_time:106738ms step_avg:131.77ms | |
step:821/1395 train_time:106875ms step_avg:131.78ms | |
step:822/1395 train_time:107011ms step_avg:131.79ms | |
step:823/1395 train_time:107147ms step_avg:131.79ms | |
step:824/1395 train_time:107283ms step_avg:131.80ms | |
step:825/1395 train_time:107421ms step_avg:131.81ms | |
step:826/1395 train_time:107560ms step_avg:131.81ms | |
step:827/1395 train_time:107698ms step_avg:131.82ms | |
step:828/1395 train_time:107836ms step_avg:131.83ms | |
step:829/1395 train_time:107977ms step_avg:131.84ms | |
step:830/1395 train_time:108113ms step_avg:131.85ms | |
step:831/1395 train_time:108253ms step_avg:131.86ms | |
step:832/1395 train_time:108392ms step_avg:131.86ms | |
step:833/1395 train_time:108532ms step_avg:131.87ms | |
step:834/1395 train_time:108672ms step_avg:131.88ms | |
step:835/1395 train_time:108810ms step_avg:131.89ms | |
step:836/1395 train_time:108950ms step_avg:131.90ms | |
step:837/1395 train_time:109086ms step_avg:131.91ms | |
step:838/1395 train_time:109225ms step_avg:131.91ms | |
step:839/1395 train_time:109362ms step_avg:131.92ms | |
step:840/1395 train_time:109499ms step_avg:131.93ms | |
step:841/1395 train_time:109637ms step_avg:131.93ms | |
step:842/1395 train_time:109775ms step_avg:131.94ms | |
step:843/1395 train_time:109913ms step_avg:131.95ms | |
step:844/1395 train_time:110050ms step_avg:131.95ms | |
step:845/1395 train_time:110187ms step_avg:131.96ms | |
step:846/1395 train_time:110327ms step_avg:131.97ms | |
step:847/1395 train_time:110466ms step_avg:131.98ms | |
step:848/1395 train_time:110603ms step_avg:131.98ms | |
step:849/1395 train_time:110741ms step_avg:131.99ms | |
step:850/1395 train_time:110879ms step_avg:132.00ms | |
step:851/1395 train_time:111020ms step_avg:132.01ms | |
step:852/1395 train_time:111159ms step_avg:132.02ms | |
step:853/1395 train_time:111298ms step_avg:132.03ms | |
step:854/1395 train_time:111435ms step_avg:132.03ms | |
step:855/1395 train_time:111576ms step_avg:132.04ms | |
step:856/1395 train_time:111712ms step_avg:132.05ms | |
step:857/1395 train_time:111851ms step_avg:132.06ms | |
step:858/1395 train_time:111991ms step_avg:132.06ms | |
step:859/1395 train_time:112130ms step_avg:132.07ms | |
step:860/1395 train_time:112266ms step_avg:132.08ms | |
step:861/1395 train_time:112405ms step_avg:132.09ms | |
step:862/1395 train_time:112543ms step_avg:132.09ms | |
step:863/1395 train_time:112686ms step_avg:132.11ms | |
step:864/1395 train_time:112825ms step_avg:132.11ms | |
step:865/1395 train_time:112961ms step_avg:132.12ms | |
step:866/1395 train_time:113107ms step_avg:132.13ms | |
step:867/1395 train_time:113245ms step_avg:132.14ms | |
step:868/1395 train_time:113381ms step_avg:132.15ms | |
step:869/1395 train_time:113518ms step_avg:132.15ms | |
step:870/1395 train_time:113658ms step_avg:132.16ms | |
step:871/1395 train_time:113796ms step_avg:132.17ms | |
step:872/1395 train_time:113935ms step_avg:132.18ms | |
step:873/1395 train_time:114075ms step_avg:132.18ms | |
step:874/1395 train_time:114214ms step_avg:132.19ms | |
step:875/1395 train_time:114354ms step_avg:132.20ms | |
step:875/1395 val_loss:3.4737 train_time:114464ms step_avg:132.33ms | |
step:876/1395 train_time:114494ms step_avg:132.21ms | |
step:877/1395 train_time:114637ms step_avg:132.22ms | |
step:878/1395 train_time:114775ms step_avg:132.23ms | |
step:879/1395 train_time:114914ms step_avg:132.24ms | |
step:880/1395 train_time:115051ms step_avg:132.24ms | |
step:881/1395 train_time:115187ms step_avg:132.25ms | |
step:882/1395 train_time:115323ms step_avg:132.25ms | |
step:883/1395 train_time:115461ms step_avg:132.26ms | |
step:884/1395 train_time:115601ms step_avg:132.27ms | |
step:885/1395 train_time:115742ms step_avg:132.28ms | |
step:886/1395 train_time:115880ms step_avg:132.28ms | |
step:887/1395 train_time:116020ms step_avg:132.29ms | |
step:888/1395 train_time:116162ms step_avg:132.30ms | |
step:889/1395 train_time:116302ms step_avg:132.31ms | |
step:890/1395 train_time:116438ms step_avg:132.32ms | |
step:891/1395 train_time:116578ms step_avg:132.32ms | |
step:892/1395 train_time:116719ms step_avg:132.34ms | |
step:893/1395 train_time:116856ms step_avg:132.34ms | |
step:894/1395 train_time:116995ms step_avg:132.35ms | |
step:895/1395 train_time:117135ms step_avg:132.36ms | |
step:896/1395 train_time:117273ms step_avg:132.36ms | |
step:897/1395 train_time:117412ms step_avg:132.37ms | |
step:898/1395 train_time:117551ms step_avg:132.38ms | |
step:899/1395 train_time:117690ms step_avg:132.38ms | |
step:900/1395 train_time:117828ms step_avg:132.39ms | |
step:901/1395 train_time:117965ms step_avg:132.40ms | |
step:902/1395 train_time:118100ms step_avg:132.40ms | |
step:903/1395 train_time:118244ms step_avg:132.41ms | |
step:904/1395 train_time:118382ms step_avg:132.42ms | |
step:905/1395 train_time:118518ms step_avg:132.42ms | |
step:906/1395 train_time:118657ms step_avg:132.43ms | |
step:907/1395 train_time:118799ms step_avg:132.44ms | |
step:908/1395 train_time:118936ms step_avg:132.45ms | |
step:909/1395 train_time:119075ms step_avg:132.45ms | |
step:910/1395 train_time:119217ms step_avg:132.46ms | |
step:911/1395 train_time:119355ms step_avg:132.47ms | |
step:912/1395 train_time:119492ms step_avg:132.48ms | |
step:913/1395 train_time:119632ms step_avg:132.48ms | |
step:914/1395 train_time:119771ms step_avg:132.49ms | |
step:915/1395 train_time:119910ms step_avg:132.50ms | |
step:916/1395 train_time:120047ms step_avg:132.50ms | |
step:917/1395 train_time:120184ms step_avg:132.51ms | |
step:918/1395 train_time:120322ms step_avg:132.51ms | |
step:919/1395 train_time:120464ms step_avg:132.52ms | |
step:920/1395 train_time:120602ms step_avg:132.53ms | |
step:921/1395 train_time:120738ms step_avg:132.53ms | |
step:922/1395 train_time:120879ms step_avg:132.54ms | |
step:923/1395 train_time:121016ms step_avg:132.55ms | |
step:924/1395 train_time:121154ms step_avg:132.55ms | |
step:925/1395 train_time:121293ms step_avg:132.56ms | |
step:926/1395 train_time:121431ms step_avg:132.57ms | |
step:927/1395 train_time:121570ms step_avg:132.57ms | |
step:928/1395 train_time:121709ms step_avg:132.58ms | |
step:929/1395 train_time:121847ms step_avg:132.59ms | |
step:930/1395 train_time:121985ms step_avg:132.59ms | |
step:931/1395 train_time:122122ms step_avg:132.60ms | |
step:932/1395 train_time:122259ms step_avg:132.60ms | |
step:933/1395 train_time:122402ms step_avg:132.61ms | |
step:934/1395 train_time:122540ms step_avg:132.62ms | |
step:935/1395 train_time:122685ms step_avg:132.63ms | |
step:936/1395 train_time:122823ms step_avg:132.64ms | |
step:937/1395 train_time:122968ms step_avg:132.65ms | |
step:938/1395 train_time:123107ms step_avg:132.66ms | |
step:939/1395 train_time:123246ms step_avg:132.66ms | |
step:940/1395 train_time:123387ms step_avg:132.67ms | |
step:941/1395 train_time:123524ms step_avg:132.68ms | |
step:942/1395 train_time:123663ms step_avg:132.69ms | |
step:943/1395 train_time:123805ms step_avg:132.70ms | |
step:944/1395 train_time:123950ms step_avg:132.71ms | |
step:945/1395 train_time:124089ms step_avg:132.72ms | |
step:946/1395 train_time:124228ms step_avg:132.72ms | |
step:947/1395 train_time:124370ms step_avg:132.73ms | |
step:948/1395 train_time:124509ms step_avg:132.74ms | |
step:949/1395 train_time:124649ms step_avg:132.75ms | |
step:950/1395 train_time:124787ms step_avg:132.75ms | |
step:951/1395 train_time:124927ms step_avg:132.76ms | |
step:952/1395 train_time:125064ms step_avg:132.76ms | |
step:953/1395 train_time:125205ms step_avg:132.77ms | |
step:954/1395 train_time:125343ms step_avg:132.78ms | |
step:955/1395 train_time:125481ms step_avg:132.78ms | |
step:956/1395 train_time:125623ms step_avg:132.79ms | |
step:957/1395 train_time:125762ms step_avg:132.80ms | |
step:958/1395 train_time:125903ms step_avg:132.81ms | |
step:959/1395 train_time:126046ms step_avg:132.82ms | |
step:960/1395 train_time:126185ms step_avg:132.83ms | |
step:961/1395 train_time:126323ms step_avg:132.83ms | |
step:962/1395 train_time:126462ms step_avg:132.84ms | |
step:963/1395 train_time:126607ms step_avg:132.85ms | |
step:964/1395 train_time:126748ms step_avg:132.86ms | |
step:965/1395 train_time:126887ms step_avg:132.87ms | |
step:966/1395 train_time:127027ms step_avg:132.87ms | |
step:967/1395 train_time:127165ms step_avg:132.88ms | |
step:968/1395 train_time:127303ms step_avg:132.88ms | |
step:969/1395 train_time:127445ms step_avg:132.89ms | |
step:970/1395 train_time:127583ms step_avg:132.90ms | |
step:971/1395 train_time:127721ms step_avg:132.90ms | |
step:972/1395 train_time:127859ms step_avg:132.91ms | |
step:973/1395 train_time:127998ms step_avg:132.92ms | |
step:974/1395 train_time:128141ms step_avg:132.93ms | |
step:975/1395 train_time:128279ms step_avg:132.93ms | |
step:976/1395 train_time:128418ms step_avg:132.94ms | |
step:977/1395 train_time:128557ms step_avg:132.94ms | |
step:978/1395 train_time:128698ms step_avg:132.95ms | |
step:979/1395 train_time:128836ms step_avg:132.96ms | |
step:980/1395 train_time:128975ms step_avg:132.96ms | |
step:981/1395 train_time:129114ms step_avg:132.97ms | |
step:982/1395 train_time:129253ms step_avg:132.98ms | |
step:983/1395 train_time:129391ms step_avg:132.98ms | |
step:984/1395 train_time:129531ms step_avg:132.99ms | |
step:985/1395 train_time:129672ms step_avg:133.00ms | |
step:986/1395 train_time:129815ms step_avg:133.01ms | |
step:987/1395 train_time:129953ms step_avg:133.01ms | |
step:988/1395 train_time:130095ms step_avg:133.02ms | |
step:989/1395 train_time:130235ms step_avg:133.03ms | |
step:990/1395 train_time:130376ms step_avg:133.04ms | |
step:991/1395 train_time:130513ms step_avg:133.04ms | |
step:992/1395 train_time:130657ms step_avg:133.05ms | |
step:993/1395 train_time:130804ms step_avg:133.07ms | |
step:994/1395 train_time:130941ms step_avg:133.07ms | |
step:995/1395 train_time:131079ms step_avg:133.07ms | |
step:996/1395 train_time:131218ms step_avg:133.08ms | |
step:997/1395 train_time:131356ms step_avg:133.09ms | |
step:998/1395 train_time:131493ms step_avg:133.09ms | |
step:999/1395 train_time:131632ms step_avg:133.10ms | |
step:1000/1395 train_time:131773ms step_avg:133.10ms | |
step:1000/1395 val_loss:3.4125 train_time:131886ms step_avg:133.22ms | |
step:1001/1395 train_time:131917ms step_avg:133.12ms | |
step:1002/1395 train_time:132058ms step_avg:133.12ms | |
step:1003/1395 train_time:132200ms step_avg:133.13ms | |
step:1004/1395 train_time:132339ms step_avg:133.14ms | |
step:1005/1395 train_time:132481ms step_avg:133.15ms | |
step:1006/1395 train_time:132621ms step_avg:133.15ms | |
step:1007/1395 train_time:132759ms step_avg:133.16ms | |
step:1008/1395 train_time:132898ms step_avg:133.16ms | |
step:1009/1395 train_time:133041ms step_avg:133.17ms | |
step:1010/1395 train_time:133180ms step_avg:133.18ms | |
step:1011/1395 train_time:133319ms step_avg:133.19ms | |
step:1012/1395 train_time:133456ms step_avg:133.19ms | |
step:1013/1395 train_time:133598ms step_avg:133.20ms | |
step:1014/1395 train_time:133736ms step_avg:133.20ms | |
step:1015/1395 train_time:133874ms step_avg:133.21ms | |
step:1016/1395 train_time:134015ms step_avg:133.22ms | |
step:1017/1395 train_time:134158ms step_avg:133.23ms | |
step:1018/1395 train_time:134299ms step_avg:133.23ms | |
step:1019/1395 train_time:134440ms step_avg:133.24ms | |
step:1020/1395 train_time:134581ms step_avg:133.25ms | |
step:1021/1395 train_time:134719ms step_avg:133.25ms | |
step:1022/1395 train_time:134859ms step_avg:133.26ms | |
step:1023/1395 train_time:135000ms step_avg:133.27ms | |
step:1024/1395 train_time:135139ms step_avg:133.27ms | |
step:1025/1395 train_time:135282ms step_avg:133.28ms | |
step:1026/1395 train_time:135422ms step_avg:133.29ms | |
step:1027/1395 train_time:135560ms step_avg:133.29ms | |
step:1028/1395 train_time:135702ms step_avg:133.30ms | |
step:1029/1395 train_time:135848ms step_avg:133.32ms | |
step:1030/1395 train_time:135989ms step_avg:133.32ms | |
step:1031/1395 train_time:136127ms step_avg:133.33ms | |
step:1032/1395 train_time:136265ms step_avg:133.33ms | |
step:1033/1395 train_time:136405ms step_avg:133.34ms | |
step:1034/1395 train_time:136547ms step_avg:133.35ms | |
step:1035/1395 train_time:136687ms step_avg:133.35ms | |
step:1036/1395 train_time:136826ms step_avg:133.36ms | |
step:1037/1395 train_time:136969ms step_avg:133.37ms | |
step:1038/1395 train_time:137110ms step_avg:133.38ms | |
step:1039/1395 train_time:137249ms step_avg:133.38ms | |
step:1040/1395 train_time:137390ms step_avg:133.39ms | |
step:1041/1395 train_time:137530ms step_avg:133.40ms | |
step:1042/1395 train_time:137669ms step_avg:133.40ms | |
step:1043/1395 train_time:137810ms step_avg:133.41ms | |
step:1044/1395 train_time:137953ms step_avg:133.42ms | |
step:1045/1395 train_time:138094ms step_avg:133.42ms | |
step:1046/1395 train_time:138235ms step_avg:133.43ms | |
step:1047/1395 train_time:138374ms step_avg:133.44ms | |
step:1048/1395 train_time:138513ms step_avg:133.44ms | |
step:1049/1395 train_time:138651ms step_avg:133.45ms | |
step:1050/1395 train_time:138791ms step_avg:133.45ms | |
step:1051/1395 train_time:138934ms step_avg:133.46ms | |
step:1052/1395 train_time:139074ms step_avg:133.47ms | |
step:1053/1395 train_time:139212ms step_avg:133.47ms | |
step:1054/1395 train_time:139350ms step_avg:133.48ms | |
step:1055/1395 train_time:139489ms step_avg:133.48ms | |
step:1056/1395 train_time:139629ms step_avg:133.49ms | |
step:1057/1395 train_time:139767ms step_avg:133.49ms | |
step:1058/1395 train_time:139909ms step_avg:133.50ms | |
step:1059/1395 train_time:140053ms step_avg:133.51ms | |
step:1060/1395 train_time:140195ms step_avg:133.52ms | |
step:1061/1395 train_time:140332ms step_avg:133.52ms | |
step:1062/1395 train_time:140473ms step_avg:133.53ms | |
step:1063/1395 train_time:140612ms step_avg:133.53ms | |
step:1064/1395 train_time:140749ms step_avg:133.54ms | |
step:1065/1395 train_time:140889ms step_avg:133.54ms | |
step:1066/1395 train_time:141030ms step_avg:133.55ms | |
step:1067/1395 train_time:141172ms step_avg:133.56ms | |
step:1068/1395 train_time:141312ms step_avg:133.57ms | |
step:1069/1395 train_time:141455ms step_avg:133.57ms | |
step:1070/1395 train_time:141594ms step_avg:133.58ms | |
step:1071/1395 train_time:141739ms step_avg:133.59ms | |
step:1072/1395 train_time:141880ms step_avg:133.60ms | |
step:1073/1395 train_time:142019ms step_avg:133.60ms | |
step:1074/1395 train_time:142157ms step_avg:133.61ms | |
step:1075/1395 train_time:142303ms step_avg:133.62ms | |
step:1076/1395 train_time:142445ms step_avg:133.63ms | |
step:1077/1395 train_time:142587ms step_avg:133.63ms | |
step:1078/1395 train_time:142730ms step_avg:133.64ms | |
step:1079/1395 train_time:142877ms step_avg:133.65ms | |
step:1080/1395 train_time:143018ms step_avg:133.66ms | |
step:1081/1395 train_time:143158ms step_avg:133.67ms | |
step:1082/1395 train_time:143298ms step_avg:133.67ms | |
step:1083/1395 train_time:143438ms step_avg:133.68ms | |
step:1084/1395 train_time:143582ms step_avg:133.69ms | |
step:1085/1395 train_time:143724ms step_avg:133.70ms | |
step:1086/1395 train_time:143866ms step_avg:133.70ms | |
step:1087/1395 train_time:144007ms step_avg:133.71ms | |
step:1088/1395 train_time:144149ms step_avg:133.72ms | |
step:1089/1395 train_time:144294ms step_avg:133.73ms | |
step:1090/1395 train_time:144438ms step_avg:133.74ms | |
step:1091/1395 train_time:144579ms step_avg:133.75ms | |
step:1092/1395 train_time:144718ms step_avg:133.75ms | |
step:1093/1395 train_time:144858ms step_avg:133.76ms | |
step:1094/1395 train_time:144999ms step_avg:133.76ms | |
step:1095/1395 train_time:145137ms step_avg:133.77ms | |
step:1096/1395 train_time:145279ms step_avg:133.77ms | |
step:1097/1395 train_time:145421ms step_avg:133.78ms | |
step:1098/1395 train_time:145563ms step_avg:133.79ms | |
step:1099/1395 train_time:145708ms step_avg:133.80ms | |
step:1100/1395 train_time:145849ms step_avg:133.81ms | |
step:1101/1395 train_time:145988ms step_avg:133.81ms | |
step:1102/1395 train_time:146131ms step_avg:133.82ms | |
step:1103/1395 train_time:146271ms step_avg:133.82ms | |
step:1104/1395 train_time:146410ms step_avg:133.83ms | |
step:1105/1395 train_time:146553ms step_avg:133.84ms | |
step:1106/1395 train_time:146693ms step_avg:133.84ms | |
step:1107/1395 train_time:146833ms step_avg:133.85ms | |
step:1108/1395 train_time:146977ms step_avg:133.86ms | |
step:1109/1395 train_time:147115ms step_avg:133.86ms | |
step:1110/1395 train_time:147255ms step_avg:133.87ms | |
step:1111/1395 train_time:147394ms step_avg:133.87ms | |
step:1112/1395 train_time:147532ms step_avg:133.88ms | |
step:1113/1395 train_time:147671ms step_avg:133.88ms | |
step:1114/1395 train_time:147813ms step_avg:133.89ms | |
step:1115/1395 train_time:147953ms step_avg:133.89ms | |
step:1116/1395 train_time:148095ms step_avg:133.90ms | |
step:1117/1395 train_time:148236ms step_avg:133.91ms | |
step:1118/1395 train_time:148382ms step_avg:133.92ms | |
step:1119/1395 train_time:148521ms step_avg:133.92ms | |
step:1120/1395 train_time:148660ms step_avg:133.93ms | |
step:1121/1395 train_time:148801ms step_avg:133.93ms | |
step:1122/1395 train_time:148942ms step_avg:133.94ms | |
step:1123/1395 train_time:149081ms step_avg:133.95ms | |
step:1124/1395 train_time:149222ms step_avg:133.95ms | |
step:1125/1395 train_time:149361ms step_avg:133.96ms | |
step:1125/1395 val_loss:3.3635 train_time:149474ms step_avg:134.06ms | |
step:1126/1395 train_time:149504ms step_avg:133.96ms | |
step:1127/1395 train_time:149646ms step_avg:133.97ms | |
step:1128/1395 train_time:149789ms step_avg:133.98ms | |
step:1129/1395 train_time:149933ms step_avg:133.99ms | |
step:1130/1395 train_time:150072ms step_avg:133.99ms | |
step:1131/1395 train_time:150214ms step_avg:134.00ms | |
step:1132/1395 train_time:150353ms step_avg:134.00ms | |
step:1133/1395 train_time:150494ms step_avg:134.01ms | |
step:1134/1395 train_time:150637ms step_avg:134.02ms | |
step:1135/1395 train_time:150776ms step_avg:134.02ms | |
step:1136/1395 train_time:150921ms step_avg:134.03ms | |
step:1137/1395 train_time:151061ms step_avg:134.04ms | |
step:1138/1395 train_time:151205ms step_avg:134.05ms | |
step:1139/1395 train_time:151347ms step_avg:134.05ms | |
step:1140/1395 train_time:151488ms step_avg:134.06ms | |
step:1141/1395 train_time:151628ms step_avg:134.07ms | |
step:1142/1395 train_time:151768ms step_avg:134.07ms | |
step:1143/1395 train_time:151913ms step_avg:134.08ms | |
step:1144/1395 train_time:152054ms step_avg:134.09ms | |
step:1145/1395 train_time:152193ms step_avg:134.09ms | |
step:1146/1395 train_time:152334ms step_avg:134.10ms | |
step:1147/1395 train_time:152475ms step_avg:134.10ms | |
step:1148/1395 train_time:152616ms step_avg:134.11ms | |
step:1149/1395 train_time:152757ms step_avg:134.12ms | |
step:1150/1395 train_time:152899ms step_avg:134.12ms | |
step:1151/1395 train_time:153045ms step_avg:134.13ms | |
step:1152/1395 train_time:153185ms step_avg:134.14ms | |
step:1153/1395 train_time:153331ms step_avg:134.15ms | |
step:1154/1395 train_time:153472ms step_avg:134.15ms | |
step:1155/1395 train_time:153611ms step_avg:134.16ms | |
step:1156/1395 train_time:153762ms step_avg:134.17ms | |
step:1157/1395 train_time:153905ms step_avg:134.18ms | |
step:1158/1395 train_time:154044ms step_avg:134.18ms | |
step:1159/1395 train_time:154184ms step_avg:134.19ms | |
step:1160/1395 train_time:154325ms step_avg:134.20ms | |
step:1161/1395 train_time:154466ms step_avg:134.20ms | |
step:1162/1395 train_time:154607ms step_avg:134.21ms | |
step:1163/1395 train_time:154752ms step_avg:134.22ms | |
step:1164/1395 train_time:154893ms step_avg:134.22ms | |
step:1165/1395 train_time:155032ms step_avg:134.23ms | |
step:1166/1395 train_time:155173ms step_avg:134.23ms | |
step:1167/1395 train_time:155315ms step_avg:134.24ms | |
step:1168/1395 train_time:155456ms step_avg:134.25ms | |
step:1169/1395 train_time:155597ms step_avg:134.25ms | |
step:1170/1395 train_time:155739ms step_avg:134.26ms | |
step:1171/1395 train_time:155879ms step_avg:134.26ms | |
step:1172/1395 train_time:156022ms step_avg:134.27ms | |
step:1173/1395 train_time:156163ms step_avg:134.28ms | |
step:1174/1395 train_time:156315ms step_avg:134.29ms | |
step:1175/1395 train_time:156459ms step_avg:134.30ms | |
step:1176/1395 train_time:156600ms step_avg:134.31ms | |
step:1177/1395 train_time:156749ms step_avg:134.32ms | |
step:1178/1395 train_time:156890ms step_avg:134.32ms | |
step:1179/1395 train_time:157028ms step_avg:134.33ms | |
step:1180/1395 train_time:157175ms step_avg:134.34ms | |
step:1181/1395 train_time:157318ms step_avg:134.35ms | |
step:1182/1395 train_time:157458ms step_avg:134.35ms | |
step:1183/1395 train_time:157602ms step_avg:134.36ms | |
step:1184/1395 train_time:157743ms step_avg:134.36ms | |
step:1185/1395 train_time:157887ms step_avg:134.37ms | |
step:1186/1395 train_time:158028ms step_avg:134.38ms | |
step:1187/1395 train_time:158181ms step_avg:134.39ms | |
step:1188/1395 train_time:158320ms step_avg:134.40ms | |
step:1189/1395 train_time:158466ms step_avg:134.41ms | |
step:1190/1395 train_time:158606ms step_avg:134.41ms | |
step:1191/1395 train_time:158748ms step_avg:134.42ms | |
step:1192/1395 train_time:158888ms step_avg:134.42ms | |
step:1193/1395 train_time:159028ms step_avg:134.43ms | |
step:1194/1395 train_time:159168ms step_avg:134.43ms | |
step:1195/1395 train_time:159313ms step_avg:134.44ms | |
step:1196/1395 train_time:159455ms step_avg:134.45ms | |
step:1197/1395 train_time:159598ms step_avg:134.45ms | |
step:1198/1395 train_time:159748ms step_avg:134.47ms | |
step:1199/1395 train_time:159888ms step_avg:134.47ms | |
step:1200/1395 train_time:160029ms step_avg:134.48ms | |
step:1201/1395 train_time:160168ms step_avg:134.48ms | |
step:1202/1395 train_time:160322ms step_avg:134.50ms | |
step:1203/1395 train_time:160468ms step_avg:134.51ms | |
step:1204/1395 train_time:160614ms step_avg:134.52ms | |
step:1205/1395 train_time:160755ms step_avg:134.52ms | |
step:1206/1395 train_time:160897ms step_avg:134.53ms | |
step:1207/1395 train_time:161039ms step_avg:134.54ms | |
step:1208/1395 train_time:161183ms step_avg:134.54ms | |
step:1209/1395 train_time:161326ms step_avg:134.55ms | |
step:1210/1395 train_time:161470ms step_avg:134.56ms | |
step:1211/1395 train_time:161614ms step_avg:134.57ms | |
step:1212/1395 train_time:161755ms step_avg:134.57ms | |
step:1213/1395 train_time:161896ms step_avg:134.58ms | |
step:1214/1395 train_time:162039ms step_avg:134.58ms | |
step:1215/1395 train_time:162185ms step_avg:134.59ms | |
step:1216/1395 train_time:162324ms step_avg:134.60ms | |
step:1217/1395 train_time:162473ms step_avg:134.61ms | |
step:1218/1395 train_time:162613ms step_avg:134.61ms | |
step:1219/1395 train_time:162751ms step_avg:134.62ms | |
step:1220/1395 train_time:162892ms step_avg:134.62ms | |
step:1221/1395 train_time:163032ms step_avg:134.63ms | |
step:1222/1395 train_time:163172ms step_avg:134.63ms | |
step:1223/1395 train_time:163311ms step_avg:134.63ms | |
step:1224/1395 train_time:163457ms step_avg:134.64ms | |
step:1225/1395 train_time:163599ms step_avg:134.65ms | |
step:1226/1395 train_time:163741ms step_avg:134.66ms | |
step:1227/1395 train_time:163881ms step_avg:134.66ms | |
step:1228/1395 train_time:164022ms step_avg:134.67ms | |
step:1229/1395 train_time:164165ms step_avg:134.67ms | |
step:1230/1395 train_time:164311ms step_avg:134.68ms | |
step:1231/1395 train_time:164454ms step_avg:134.69ms | |
step:1232/1395 train_time:164599ms step_avg:134.70ms | |
step:1233/1395 train_time:164741ms step_avg:134.70ms | |
step:1234/1395 train_time:164881ms step_avg:134.71ms | |
step:1235/1395 train_time:165021ms step_avg:134.71ms | |
step:1236/1395 train_time:165162ms step_avg:134.72ms | |
step:1237/1395 train_time:165303ms step_avg:134.72ms | |
step:1238/1395 train_time:165455ms step_avg:134.74ms | |
step:1239/1395 train_time:165595ms step_avg:134.74ms | |
step:1240/1395 train_time:165737ms step_avg:134.75ms | |
step:1241/1395 train_time:165881ms step_avg:134.75ms | |
step:1242/1395 train_time:166021ms step_avg:134.76ms | |
step:1243/1395 train_time:166166ms step_avg:134.77ms | |
step:1244/1395 train_time:166307ms step_avg:134.77ms | |
step:1245/1395 train_time:166449ms step_avg:134.78ms | |
step:1246/1395 train_time:166588ms step_avg:134.78ms | |
step:1247/1395 train_time:166730ms step_avg:134.79ms | |
step:1248/1395 train_time:166870ms step_avg:134.79ms | |
step:1249/1395 train_time:167010ms step_avg:134.79ms | |
step:1250/1395 train_time:167151ms step_avg:134.80ms | |
step:1250/1395 val_loss:3.3176 train_time:167266ms step_avg:134.89ms | |
step:1251/1395 train_time:167298ms step_avg:134.81ms | |
step:1252/1395 train_time:167441ms step_avg:134.82ms | |
step:1253/1395 train_time:167583ms step_avg:134.82ms | |
step:1254/1395 train_time:167726ms step_avg:134.83ms | |
step:1255/1395 train_time:167880ms step_avg:134.84ms | |
step:1256/1395 train_time:168022ms step_avg:134.85ms | |
step:1257/1395 train_time:168164ms step_avg:134.86ms | |
step:1258/1395 train_time:168308ms step_avg:134.86ms | |
step:1259/1395 train_time:168451ms step_avg:134.87ms | |
step:1260/1395 train_time:168591ms step_avg:134.87ms | |
step:1261/1395 train_time:168733ms step_avg:134.88ms | |
step:1262/1395 train_time:168882ms step_avg:134.89ms | |
step:1263/1395 train_time:169028ms step_avg:134.90ms | |
step:1264/1395 train_time:169168ms step_avg:134.90ms | |
step:1265/1395 train_time:169307ms step_avg:134.91ms | |
step:1266/1395 train_time:169451ms step_avg:134.91ms | |
step:1267/1395 train_time:169593ms step_avg:134.92ms | |
step:1268/1395 train_time:169735ms step_avg:134.92ms | |
step:1269/1395 train_time:169882ms step_avg:134.93ms | |
step:1270/1395 train_time:170024ms step_avg:134.94ms | |
step:1271/1395 train_time:170167ms step_avg:134.95ms | |
step:1272/1395 train_time:170306ms step_avg:134.95ms | |
step:1273/1395 train_time:170446ms step_avg:134.95ms | |
step:1274/1395 train_time:170586ms step_avg:134.96ms | |
step:1275/1395 train_time:170730ms step_avg:134.96ms | |
step:1276/1395 train_time:170870ms step_avg:134.97ms | |
step:1277/1395 train_time:171012ms step_avg:134.97ms | |
step:1278/1395 train_time:171154ms step_avg:134.98ms | |
step:1279/1395 train_time:171296ms step_avg:134.99ms | |
step:1280/1395 train_time:171446ms step_avg:135.00ms | |
step:1281/1395 train_time:171587ms step_avg:135.00ms | |
step:1282/1395 train_time:171727ms step_avg:135.01ms | |
step:1283/1395 train_time:171869ms step_avg:135.01ms | |
step:1284/1395 train_time:172012ms step_avg:135.02ms | |
step:1285/1395 train_time:172153ms step_avg:135.02ms | |
step:1286/1395 train_time:172295ms step_avg:135.03ms | |
step:1287/1395 train_time:172439ms step_avg:135.03ms | |
step:1288/1395 train_time:172580ms step_avg:135.04ms | |
step:1289/1395 train_time:172732ms step_avg:135.05ms | |
step:1290/1395 train_time:172879ms step_avg:135.06ms | |
step:1291/1395 train_time:173025ms step_avg:135.07ms | |
step:1292/1395 train_time:173169ms step_avg:135.08ms | |
step:1293/1395 train_time:173317ms step_avg:135.09ms | |
step:1294/1395 train_time:173458ms step_avg:135.09ms | |
step:1295/1395 train_time:173600ms step_avg:135.10ms | |
step:1296/1395 train_time:173746ms step_avg:135.11ms | |
step:1297/1395 train_time:173890ms step_avg:135.11ms | |
step:1298/1395 train_time:174029ms step_avg:135.12ms | |
step:1299/1395 train_time:174172ms step_avg:135.12ms | |
step:1300/1395 train_time:174312ms step_avg:135.13ms | |
step:1301/1395 train_time:174453ms step_avg:135.13ms | |
step:1302/1395 train_time:174595ms step_avg:135.14ms | |
step:1303/1395 train_time:174741ms step_avg:135.14ms | |
step:1304/1395 train_time:174885ms step_avg:135.15ms | |
step:1305/1395 train_time:175026ms step_avg:135.16ms | |
step:1306/1395 train_time:175169ms step_avg:135.16ms | |
step:1307/1395 train_time:175312ms step_avg:135.17ms | |
step:1308/1395 train_time:175456ms step_avg:135.17ms | |
step:1309/1395 train_time:175600ms step_avg:135.18ms | |
step:1310/1395 train_time:175741ms step_avg:135.19ms | |
step:1311/1395 train_time:175880ms step_avg:135.19ms | |
step:1312/1395 train_time:176023ms step_avg:135.19ms | |
step:1313/1395 train_time:176165ms step_avg:135.20ms | |
step:1314/1395 train_time:176307ms step_avg:135.20ms | |
step:1315/1395 train_time:176451ms step_avg:135.21ms | |
step:1316/1395 train_time:176590ms step_avg:135.21ms | |
step:1317/1395 train_time:176733ms step_avg:135.22ms | |
step:1318/1395 train_time:176880ms step_avg:135.23ms | |
step:1319/1395 train_time:177024ms step_avg:135.24ms | |
step:1320/1395 train_time:177164ms step_avg:135.24ms | |
step:1321/1395 train_time:177304ms step_avg:135.24ms | |
step:1322/1395 train_time:177453ms step_avg:135.25ms | |
step:1323/1395 train_time:177595ms step_avg:135.26ms | |
step:1324/1395 train_time:177739ms step_avg:135.27ms | |
step:1325/1395 train_time:177882ms step_avg:135.27ms | |
step:1326/1395 train_time:178029ms step_avg:135.28ms | |
step:1327/1395 train_time:178171ms step_avg:135.29ms | |
step:1328/1395 train_time:178310ms step_avg:135.29ms | |
step:1329/1395 train_time:178464ms step_avg:135.30ms | |
step:1330/1395 train_time:178610ms step_avg:135.31ms | |
step:1331/1395 train_time:178755ms step_avg:135.32ms | |
step:1332/1395 train_time:178904ms step_avg:135.33ms | |
step:1333/1395 train_time:179049ms step_avg:135.34ms | |
step:1334/1395 train_time:179189ms step_avg:135.34ms | |
step:1335/1395 train_time:179328ms step_avg:135.34ms | |
step:1336/1395 train_time:179478ms step_avg:135.35ms | |
step:1337/1395 train_time:179624ms step_avg:135.36ms | |
step:1338/1395 train_time:179764ms step_avg:135.36ms | |
step:1339/1395 train_time:179907ms step_avg:135.37ms | |
step:1340/1395 train_time:180052ms step_avg:135.38ms | |
step:1341/1395 train_time:180193ms step_avg:135.38ms | |
step:1342/1395 train_time:180336ms step_avg:135.39ms | |
step:1343/1395 train_time:180478ms step_avg:135.39ms | |
step:1344/1395 train_time:180619ms step_avg:135.40ms | |
step:1345/1395 train_time:180761ms step_avg:135.40ms | |
step:1346/1395 train_time:180902ms step_avg:135.41ms | |
step:1347/1395 train_time:181046ms step_avg:135.41ms | |
step:1348/1395 train_time:181187ms step_avg:135.42ms | |
step:1349/1395 train_time:181328ms step_avg:135.42ms | |
step:1350/1395 train_time:181469ms step_avg:135.42ms | |
step:1351/1395 train_time:181612ms step_avg:135.43ms | |
step:1352/1395 train_time:181762ms step_avg:135.44ms | |
step:1353/1395 train_time:181909ms step_avg:135.45ms | |
step:1354/1395 train_time:182053ms step_avg:135.46ms | |
step:1355/1395 train_time:182193ms step_avg:135.46ms | |
step:1356/1395 train_time:182333ms step_avg:135.46ms | |
step:1357/1395 train_time:182477ms step_avg:135.47ms | |
step:1358/1395 train_time:182623ms step_avg:135.48ms | |
step:1359/1395 train_time:182765ms step_avg:135.48ms | |
step:1360/1395 train_time:182910ms step_avg:135.49ms | |
step:1361/1395 train_time:183055ms step_avg:135.50ms | |
step:1362/1395 train_time:183201ms step_avg:135.50ms | |
step:1363/1395 train_time:183352ms step_avg:135.52ms | |
step:1364/1395 train_time:183493ms step_avg:135.52ms | |
step:1365/1395 train_time:183631ms step_avg:135.52ms | |
step:1366/1395 train_time:183773ms step_avg:135.53ms | |
step:1367/1395 train_time:183917ms step_avg:135.53ms | |
step:1368/1395 train_time:184062ms step_avg:135.54ms | |
step:1369/1395 train_time:184209ms step_avg:135.55ms | |
step:1370/1395 train_time:184359ms step_avg:135.56ms | |
step:1371/1395 train_time:184503ms step_avg:135.56ms | |
step:1372/1395 train_time:184651ms step_avg:135.57ms | |
step:1373/1395 train_time:184793ms step_avg:135.58ms | |
step:1374/1395 train_time:184941ms step_avg:135.59ms | |
step:1375/1395 train_time:185084ms step_avg:135.59ms | |
step:1375/1395 val_loss:3.2838 train_time:185196ms step_avg:135.67ms | |
step:1376/1395 train_time:185227ms step_avg:135.60ms | |
step:1377/1395 train_time:185375ms step_avg:135.61ms | |
step:1378/1395 train_time:185516ms step_avg:135.61ms | |
step:1379/1395 train_time:185660ms step_avg:135.62ms | |
step:1380/1395 train_time:185802ms step_avg:135.62ms | |
step:1381/1395 train_time:185949ms step_avg:135.63ms | |
step:1382/1395 train_time:186093ms step_avg:135.64ms | |
step:1383/1395 train_time:186235ms step_avg:135.64ms | |
step:1384/1395 train_time:186383ms step_avg:135.65ms | |
step:1385/1395 train_time:186522ms step_avg:135.65ms | |
step:1386/1395 train_time:186662ms step_avg:135.66ms | |
step:1387/1395 train_time:186805ms step_avg:135.66ms | |
step:1388/1395 train_time:186945ms step_avg:135.66ms | |
step:1389/1395 train_time:187088ms step_avg:135.67ms | |
step:1390/1395 train_time:187228ms step_avg:135.67ms | |
step:1391/1395 train_time:187370ms step_avg:135.68ms | |
step:1392/1395 train_time:187516ms step_avg:135.68ms | |
step:1393/1395 train_time:187658ms step_avg:135.69ms | |
step:1394/1395 train_time:187801ms step_avg:135.69ms | |
step:1395/1395 train_time:187941ms step_avg:135.70ms | |
step:1395/1395 val_loss:3.2793 train_time:188056ms step_avg:135.78ms | |
peak memory allocated: 37619 MiB reserved: 39214 MiB
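For reference, a minimal sketch of how the per-step timing lines and the peak-memory summary above are typically produced in a PyTorch training loop. This is illustrative only (the helper names are made up, not taken from the script); the torch.cuda calls themselves are the standard APIs for these statistics.

import time
import torch

def log_step(step: int, total_steps: int, t0: float) -> None:
    # Synchronize so wall-clock time reflects completed CUDA work,
    # not just kernels queued on the stream.
    torch.cuda.synchronize()
    train_time_ms = 1000 * (time.perf_counter() - t0)
    print(f"step:{step}/{total_steps} train_time:{train_time_ms:.0f}ms "
          f"step_avg:{train_time_ms / max(step, 1):.2f}ms")

def log_peak_memory() -> None:
    # Mirrors the final summary line: peak allocated vs. reserved, in MiB.
    alloc = torch.cuda.max_memory_allocated() // (1 << 20)
    reserved = torch.cuda.max_memory_reserved() // (1 << 20)
    print(f"peak memory allocated: {alloc} MiB reserved: {reserved} MiB")

Reserved exceeding allocated (39214 vs. 37619 MiB here) is expected: the caching allocator holds segments beyond live tensors, and the expandable_segments option set at the top of the script keeps that gap small.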
[project]
name = "modded-nanogpt"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = "==3.12.*"
dependencies = [
    "numpy>=2.1.3",
    "torch",
    "pytorch-triton>=3.2.0",
    "huggingface-hub>=0.26.2",
    "tqdm>=4.67.0",
    "pip>=24.3.1",
]

[tool.uv]
environments = [
    "sys_platform == 'linux'",
]

[tool.uv.sources]
torch = [
    { url = "https://download.pytorch.org/whl/nightly/cu126/torch-2.7.0.dev20250110%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" },
]
pytorch-triton = [
    { index = "pytorch-nightly-cu126" },
]

[[tool.uv.index]]
name = "pytorch-nightly-cu126"
url = "https://download.pytorch.org/whl/nightly/cu126"
explicit = true
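The config pins a specific PyTorch nightly wheel (2.7.0.dev20250110+cu126) on Python 3.12/Linux, with pytorch-triton pulled from the same explicit cu126 nightly index so the two stay aligned. With uv installed, `uv sync` should materialize this environment. A quick sanity check that the pins resolved as intended (a hedged sketch, not part of the gist; the version strings are copied from the pinned wheel above):

import torch

print(torch.__version__)
assert torch.__version__.startswith("2.7.0.dev20250110"), "unexpected torch build"
assert torch.version.cuda == "12.6", "expected a cu126 wheel"
assert torch.cuda.is_available(), "training requires CUDA GPUs"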