@lapp0
Created December 23, 2024 22:12
ESM2 Train
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import time
import contextlib
from dataclasses import dataclass
from pathlib import Path
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# -----------------------------------------------------------------------------
# Muon optimizer
@torch.compile
def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X
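# Illustrative check (a minimal sketch, not called during training, assuming a CUDA device like the
# rest of this script): for a random matrix, the singular values of the Newton-Schulz output should
# land roughly in the [0.5, 1.5] band described in the docstring above, rather than exactly at 1 as
# an exact orthogonalization would give.
def _orthogonality_check(rows=256, cols=128, steps=5):
    G = torch.randn(rows, cols, device='cuda')
    X = zeropower_via_newtonschulz5(G, steps).float()
    return torch.linalg.svdvals(X) # expected to lie approximately within [0.5, 1.5]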
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz
    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.
    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
      parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        self.world_size = int(os.environ['WORLD_SIZE'])
        self.rank = int(os.environ['RANK'])
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params = list(params)
        assert all(isinstance(p, torch.Tensor) for p in params)
        sizes = {p.numel() for p in params}
        param_groups = [
            {
                'params': [p for p in params if p.numel() == size],
                'update_buffer': [
                    torch.empty(size, device='cuda', dtype=torch.bfloat16)
                    for _ in range(self.world_size)
                ],
            }
            for size in sizes
        ]
        super().__init__(param_groups, defaults)
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            nesterov = group['nesterov']
            ns_steps = group['ns_steps']
            update_buffers = group['update_buffer']
            # generate weight updates in distributed fashion
            params = group['params']
            assert len(params) % self.world_size == 0
            handle = None
            params_world = None
            def update_prev():
                if params_world is None:
                    return
                assert handle is not None
                handle.wait()
                for p_world, g_world in zip(params_world, update_buffers):
                    p_world.data.add_(
                        g_world.view_as(p_world),
                        alpha=-lr * max(1, p_world.size(0) / p_world.size(1)) ** 0.5,
                    )
            for base_i in range(len(params))[::self.world_size]:
                p = params[base_i + self.rank]
                g = p.grad
                assert g is not None
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.lerp_(g, 1 - momentum)
                g = g.lerp_(buf, momentum) if nesterov else buf
                g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten()
                update_prev()
                handle = dist.all_gather(update_buffers, g, async_op=True)
                params_world = params[base_i : base_i + self.world_size]
            update_prev()
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions
def norm(x):
    return F.rms_norm(x, (x.size(-1),))
class CastedLinear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
    def forward(self, x):
        return F.linear(x, self.weight.to(x.dtype))
class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.register_buffer('inv_freq', (1 / base) ** (torch.arange(0, dim, 2) / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            t = torch.arange(seq_len, device=x.device)
            freqs = torch.outer(t, self.inv_freq)
            self.seq_len_cached = seq_len
            self.cos_cached = freqs.cos()
            self.sin_cached = freqs.sin()
        cos, sin = self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
        # apply_rotary_emb(x, cos, sin)
        x1, x2 = x.chunk(2, dim=3)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x)
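# Illustrative check (a minimal sketch, not called during training): the rotation above mixes each
# pair (x1_i, x2_i) by an angle, so per-position norms along the head dimension are preserved.
# Shapes follow the (batch, seq, heads, head_dim) layout used by the attention module below.
def _rotary_norm_check(T=16, num_heads=6, head_dim=128):
    rot = Rotary(head_dim)
    x = torch.randn(1, T, num_heads, head_dim)
    y = rot(x)
    return torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-4)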
# note: despite the class name, the block_mask built in BERT.encoder_pass below is bidirectional
# (sliding window + document mask), so this runs as ordinary bidirectional self-attention for MLM
class CausalSelfAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, dim)
        self.c_v = CastedLinear(dim, dim)
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(dim // num_heads) # dim // num_heads = head_dim
        self.c_proj = CastedLinear(dim, dim)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
    def forward(self, x, vi, block_mask):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        q = self.c_q(x).view(B, T, self.num_heads, -1)
        k = self.c_k(x).view(B, T, self.num_heads, -1)
        v = self.c_v(x).view(B, T, self.num_heads, -1)
        v = self.lambdas[0] * v + self.lambdas[1] * vi.view_as(v) # @KoszarskyB & @Grad62304977
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, enable_gqa=True, kernel_options = {
            "BLOCK_M": 64, "BLOCK_N": 64, # forward
            "BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32 # backwards
        })
        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = CastedLinear(dim, 4 * dim)
        self.c_proj = CastedLinear(4 * dim, dim)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config.model_dim, config.num_heads)
        self.mlp = MLP(config.model_dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
    def forward(self, x, vi, x0, block_mask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        x = x + self.attn(norm(x), vi, block_mask)
        x = x + self.mlp(norm(x))
        return x
class ValueEmbedding(nn.Module):
    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.embed = nn.ModuleList([
            nn.Embedding(config.vocab_size, config.model_dim)
            for _ in range(6)
        ])
    def forward(self, inputs) -> "list[torch.Tensor]":
        ve = [emb(inputs) for emb in self.embed]
        ve += reversed(ve) # mirror the 6 embeddings so block i and block (num_layers-1-i) share one (U-net pairing)
        return ve
# -----------------------------------------------------------------------------
# The main ESM Bert model
class BERT(nn.Module):
    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.mask_id = 32
        self.bos_id = 33
        self.num_layers = config.num_layers
        # U-net design by @brendanh0gan
        self.num_encoder_layers = config.num_layers // 2 # Half of the layers for encoder
        self.num_decoder_layers = config.num_layers - self.num_encoder_layers # Remaining for decoder
        # Add learnable skip connection weights for decoder layers
        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
        self.embed = nn.Embedding(config.vocab_size, config.model_dim)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
        # U-net structure on token value embeddings by @leloykun
        self.value_embeds = ValueEmbedding(config)
        self.lm_head = CastedLinear(config.model_dim, config.vocab_size)
        self.lm_head.weight.data.zero_() # @Grad62304977
    def encoder_pass(self, input_seq: torch.Tensor, sliding_window_size: torch.Tensor):
        docs = (input_seq == self.bos_id).cumsum(0)
        def doc_mask_mod(b, h, q_idx, kv_idx):
            bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
            doc_mask = docs[q_idx] == docs[kv_idx]
            return bidirectional_sliding_window_mask & doc_mask
        S = len(input_seq)
        block_mask = create_block_mask(
            doc_mask_mod, None, None, S, S,
        )
        x = self.embed(input_seq[None])
        x = norm(x) # @Grad62304977
        x0 = x
        ve = self.value_embeds(input_seq)
        ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
        # Store outputs for U-Net skip connections
        skip_connections = []
        # Encoder pass - process only the first half of the blocks
        for i in range(self.num_encoder_layers):
            x = self.blocks[i](x, ve_enc[i], x0, block_mask)
            skip_connections.append(x)
        # Decoder pass - process the remaining blocks with weighted skip connections
        for i in range(self.num_decoder_layers):
            x = x + self.skip_weights[i] * skip_connections.pop()
            # U-net structure on token value embeddings by @leloykun
            x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask)
        x = norm(x)
        logits = self.lm_head(x)
        logits = 30 * torch.tanh(logits / 30) # @Grad62304977
        logits = logits.float()
        return logits
    def forward(self, seq, sliding_window_size: torch.Tensor):
        # MLM mask/replace constants from https://www.biorxiv.org/content/10.1101/2022.07.20.500902v3.full.pdf
        pct_masked = 0.12
        pct_replaced = 0.015
        # set pct_masked% to <mask>
        mlm_mask = self.get_frac_mask(seq, pct_masked)
        input_seq = seq.clone().masked_fill(mlm_mask, self.mask_id)
        # substitute pct_replaced% with token id between 4 and 30 (inclusive)
        sub_mask = self.get_frac_mask(seq, pct_replaced, include=~mlm_mask)
        input_seq[sub_mask] = torch.randint(4, 31, (sub_mask.sum(),), dtype=seq.dtype, device=seq.device)
        logits = self.encoder_pass(input_seq, sliding_window_size)
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            seq.masked_fill(~mlm_mask, -100).to(dtype=torch.int64).view(-1),
            ignore_index=-100
        )
    def get_frac_mask(self, seq: torch.Tensor, pct: float, include=None):
        docs = (seq == self.bos_id).cumsum(0)
        valid_tokens_mask = (seq >= 4) & (seq <= 30)
        if include is not None:
            valid_tokens_mask &= include
        random_values = torch.rand_like(docs, dtype=torch.float) * valid_tokens_mask
        # Map each token to its doc index, count tokens per doc, and compute how many to mask
        _, inv_docs = torch.unique(docs, return_inverse=True)
        doc_counts = torch.bincount(inv_docs) # total tokens in each doc
        num_to_mask = (doc_counts.float() * pct).ceil().to(torch.int64)
        # Rank tokens globally by random value and select num_to_mask
        sorted_indices = torch.argsort(random_values, descending=True)
        ranks = torch.empty_like(sorted_indices, dtype=torch.int64)
        ranks[sorted_indices] = torch.arange(len(seq), device=seq.device)
        return ranks < num_to_mask[inv_docs]
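# A minimal sketch of the corruption scheme above on a toy sequence (not called during training).
# For simplicity it uses i.i.d. Bernoulli draws rather than the per-document exact-fraction ranking
# implemented in get_frac_mask; token ids follow the convention above (33 = BOS, 32 = <mask>,
# 4..30 = ordinary residue tokens).
def _mlm_corruption_sketch(length=1000):
    seq = torch.randint(4, 31, (length,))
    seq[::100] = 33 # pretend every 100th position starts a new document
    residues = (seq >= 4) & (seq <= 30)
    mask = (torch.rand(length) < 0.12) & residues # ~12% of residues become <mask>
    sub = (torch.rand(length) < 0.015) & residues & ~mask # ~1.5% get a random residue instead
    corrupted = seq.clone()
    corrupted[mask] = 32
    corrupted[sub] = torch.randint(4, 31, (int(sub.sum()),))
    return corrupted, mask # the loss above is computed only at the masked positions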
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader
def _peek_data_shard(file: Path):
    # only reads the header, returns header data
    # header is 256 int32
    header = torch.from_file(f"{file}", False, 256, dtype=torch.int32)
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    return int(header[2]) # number of tokens (claimed)
def _load_data_shard(path: Path, num_tokens):
    with path.open("rb", buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy())
        assert nbytes == 2 * num_tokens, "number of tokens read does not match header?"
    return tokens
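# A minimal sketch (hypothetical helper, not part of the training run) of writing a shard in the
# format the two readers above expect: a 256-int32 header whose first three entries are the magic
# number 20240520, the version 1, and the token count, followed by the tokens as uint16.
def _write_data_shard_sketch(path: Path, tokens: torch.Tensor):
    import numpy as np
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520 # magic number checked by _peek_data_shard
    header[1] = 1 # version
    header[2] = len(tokens) # number of tokens in the shard
    with open(path, "wb") as f:
        f.write(header.tobytes())
        f.write(tokens.to(torch.uint16).numpy().tobytes())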
class DistributedDataLoader:
    def __init__(self, filename_pattern, seq_len, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.seq_len = seq_len
        # glob files that match the pattern
        self.files = sorted(Path.cwd().glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
        # load and validate all data shards, count number of tokens in total
        self.files_num_tokens = [_peek_data_shard(file) for file in self.files]
        assert min(self.files_num_tokens) >= num_processes * seq_len + 1
        self.total_num_tokens = sum(self.files_num_tokens)
        self.reset()
    def reset(self):
        self.current_shard = -1
        self.advance()
    def advance(self): # advance to next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.seq_len
        self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard])
    def next_batch(self):
        batch_size = self.seq_len * self.num_processes
        buf = self.tokens[self.current_position:self.current_position+self.seq_len+1]
        # host side async is sufficient;
        # no performance improvement was observed when introducing a separate stream.
        seq = buf.to(device="cuda", dtype=torch.int32, non_blocking=True) # inputs
        # advance current position and load next shard if necessary
        self.current_position += batch_size
        if self.current_position + batch_size + 1 >= len(self.tokens):
            self.advance()
        return seq
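# Offset arithmetic used by the loader above, as a standalone sketch: rank r starts at r * seq_len
# within the shard and every next_batch() advances all ranks by seq_len * num_processes, so the
# per-rank windows tile the shard without overlapping.
def _shard_offset(rank, batch_idx, seq_len, num_processes):
    return rank * seq_len + batch_idx * seq_len * num_processes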
# -----------------------------------------------------------------------------
# int main
@dataclass
class Hyperparameters:
    # data hyperparams
    input_bin : str = 'data/omgprot50/omgprot50_train_*.bin' # input .bin to train on
    input_val_bin : str = 'data/omgprot50/omgprot50_val_*.bin' # input .bin to eval validation loss on
    # optimization hyperparams
    batch_size : int = 16 # batch size, in sequences, across all devices
    sequence_length : int = 32*1024 # sequence length, in tokens
    num_iterations : int = 1480 # number of iterations to run
    warmup_iters : int = 0
    cooldown_iters : int = 600 # number of iterations of linear warmup/cooldown for triangular or trapezoidal schedule
    weight_decay : float = 0
    # evaluation and logging hyperparams
    val_loss_every : int = 25 # every how many steps to evaluate val loss? 0 for only at the end
    val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
@dataclass
class ModelConfig:
    # 33 tokens: https://huggingface.co/Synthyra/ESMplusplus_large/blob/main/modeling_esm_plusplus.py#L868-L874
    # For PLMs, the number of layers (depth) typically matters more than the hidden dimension (width)
    # ESM2-8M has 6 layers, 20 heads, 320 hidden dim: https://huggingface.co/facebook/esm2_t6_8M_UR50D/blob/main/config.json
    # ESM2-35M has 12 layers, 20 heads, 480 hidden dim: https://huggingface.co/facebook/esm2_t12_35M_UR50D/blob/main/config.json
    # ESM2-150M has 30 layers, 20 heads, 640 hidden dim: https://huggingface.co/facebook/esm2_t30_150M_UR50D/blob/main/config.json
    # ESM2-650M has 33 layers, 20 heads, 1280 hidden dim: https://huggingface.co/facebook/esm2_t33_650M_UR50D/blob/main/config.json
    vocab_size : int = 34 # normal vocab plus mock BOS
    num_layers : int = 12
    num_heads : int = 6 # head dim 128 suggested by @Grad62304977
    model_dim : int = 768
model_config = ModelConfig()
args = Hyperparameters()
def get_param_count(model):
    total_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
    return total_params
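# Quick breakdown of where the parameters come from with the config above, as a sketch: each block
# has four dim x dim attention projections plus two lambdas, and a dim x 4*dim / 4*dim x dim MLP
# plus two block lambdas; embeddings comprise the token embedding, six value embeddings, and the
# lm_head. For num_layers=12, model_dim=768, vocab_size=34 this gives the 85,143,606 reported in
# the log below.
def _expected_param_count(cfg=ModelConfig()):
    d, L, V = cfg.model_dim, cfg.num_layers, cfg.vocab_size
    per_block = (4 * d * d + 2) + (8 * d * d + 2) # attention + MLP
    embeds = V * d + 6 * V * d + V * d # token embed + value embeds + lm_head
    return L * per_block + embeds + (L - L // 2) # plus one skip weight per decoder layer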
# set up DDP (distributed data parallel). torchrun sets this env variable
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
assert torch.cuda.is_available()
device = torch.device(f'cuda:{ddp_local_rank}')
torch.cuda.set_device(device)
print(f'using device: {device}')
dist.init_process_group(backend='nccl', device_id=device)
dist.barrier()
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
# begin logging
logfile = None
if master_process:
    run_id = uuid.uuid4()
    Path('logs').mkdir(exist_ok=True)
    # logdir = Path('logs') / f'{run_id}'
    # logdir.mkdir()
    logfile = Path('logs') / f'{run_id}.txt'
    print(logfile.stem)
    # create the log file
    with logfile.open('w') as f:
        # begin the log by printing this file (the Python code)
        print(code, file=f)
        print('=' * 100, file=f)
def print0(s, logonly=False):
    if master_process:
        with logfile.open('a') as f:
            if not logonly:
                print(s)
            print(s, file=f)
# log information about the hardware/software environment this is running on
# and print the full `nvidia-smi` to file
print0(f'Running python {sys.version}')
print0(f'Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:')
import subprocess
result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print0(f'{result.stdout}', logonly=True)
print0('='*100, logonly=True)
# calculate the number of steps to take in the val loop.
assert args.val_tokens % (args.sequence_length * ddp_world_size) == 0
val_steps = args.val_tokens // (args.sequence_length * ddp_world_size)
# calculate the steps of gradient accumulation required to attain the desired global batch size.
assert args.batch_size % (ddp_world_size) == 0
train_accumulation_steps = args.batch_size // ddp_world_size
# load tokens
train_loader = DistributedDataLoader(args.input_bin, args.sequence_length, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, args.sequence_length, ddp_rank, ddp_world_size)
print0(f"Training DataLoader: total number of tokens: {train_loader.total_num_tokens} across {len(train_loader.files)} files")
print0(f"Validation DataLoader: total number of tokens: {val_loader.total_num_tokens} across {len(val_loader.files)} files")
print0('='*100, logonly=True)
seq_train = train_loader.next_batch()
model = BERT(model_config)
model = model.cuda().bfloat16()
for m in model.modules():
    if isinstance(m, CastedLinear):
        m.float()
config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank], broadcast_buffers=False, gradient_as_bucket_view=True)
raw_model = model.module # always contains the "raw" unwrapped model
# init the optimizer(s)
embed_params = [*raw_model.embed.parameters(), *raw_model.value_embeds.parameters()]
optimizer1 = torch.optim.Adam(embed_params, lr=0.6, betas=(0.8, 0.95), fused=True)
optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.8, 0.95), fused=True)
params = list(raw_model.blocks.parameters())
matrix_params = [p for p in params if p.ndim == 2]
scalar_params = [p for p in params if p.ndim < 2] + [raw_model.skip_weights]
optimizer3 = Muon(matrix_params, lr=0.05, momentum=0.95)
optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.8, 0.95), fused=True)
optimizers = [optimizer1, optimizer2, optimizer3, optimizer4]
# learning rate decay scheduler (linear warmup and cooldown)
def get_lr(it):
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while
    elif it < args.num_iterations - args.cooldown_iters:
        return 1.0
    # 3) linear cooldown
    else:
        decay_ratio = (args.num_iterations - it) / args.cooldown_iters
        return decay_ratio
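# With the defaults above (1480 iterations, no warmup, 600 cooldown iterations) this is a
# trapezoidal schedule: the multiplier stays at 1.0 through step 880, then decays linearly to 0
# at step 1480. A tiny sketch, not called during training:
def _print_lr_schedule(steps=(0, 440, 880, 1030, 1180, 1330, 1480)):
    for it in steps:
        print(f"step {it:4d}: lr multiplier {get_lr(it):.3f}")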
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
sliding_window_size = torch.tensor(128, dtype=torch.int32, device="cuda")
sw_prev = 128
# Start training loop
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
for step in range(args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.perf_counter()
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
    # Linearly increase the sliding window size over training in chunks of 128 from 128 -> 2048. By @fernbear.bsky.social
    frac_done = step / args.num_iterations # training progress
    sw_size = int(((1 - frac_done) * 128 + frac_done * 2048) // 128) * 128
    if sw_size != sw_prev:
        sliding_window_size.copy_(sw_size, non_blocking=True)
        sw_prev = sw_size
    # once in a while evaluate the validation dataset
    if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        # run validation batches
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        for _ in range(val_steps):
            with torch.no_grad():
                seq_val = val_loader.next_batch()
                val_loss += model(seq_val, sliding_window_size)
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # log val loss to console and to logfile
        print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms perplexity:{(2**val_loss):.4f} param_count:{get_param_count(model):,}')
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.perf_counter()
    # uncomment if you want to save any checkpoints
    #save_every = 1000
    #if master_process and (last_step or (save_every > 0 and step % save_every == 0)):
    #    # stop the clock
    #    torch.cuda.synchronize()
    #    training_time_ms += 1000 * (time.perf_counter() - t0)
    #    # save the state of the training process
    #    log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
    #    torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
    #    # start the clock again
    #    torch.cuda.synchronize()
    #    t0 = time.perf_counter()
    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break
    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    for i in range(1, train_accumulation_steps + 1):
        with contextlib.ExitStack() as stack:
            if i < train_accumulation_steps: # there's no need to sync gradients every accumulation step
                stack.enter_context(model.no_sync())
            if step >= 5:
                stack.enter_context(torch.compiler.set_stance(skip_guard_eval_unsafe=True))
            model(seq_train, sliding_window_size).backward()
            seq_train = train_loader.next_batch()
    if train_accumulation_steps != 1:
        for p in model.parameters():
            p.grad /= train_accumulation_steps
    # momentum warmup for Muon
    frac = min(step/300, 1)
    for group in optimizer3.param_groups:
        group['momentum'] = (1 - frac) * 0.85 + frac * 0.95
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------
    # everything that follows now is just diagnostics, prints, logging, etc.
    approx_time = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f"step:{step+1}/{args.num_iterations} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
# -------------------------------------------------------------------------
# clean up nice
dist.destroy_process_group()
====================================================================================================
Running python 3.11.10 | packaged by conda-forge | (main, Oct 16 2024, 01:27:36) [GCC 13.3.0]
Running pytorch 2.6.0.dev20241203+cu124 compiled for CUDA 12.4
nvidia-smi:
Mon Dec 23 21:13:38 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4090 On | 00000000:42:00.0 Off | Off |
| 30% 26C P2 60W / 450W | 1756MiB / 24564MiB | 36% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 On | 00000000:81:00.0 Off | Off |
| 31% 27C P2 54W / 450W | 591MiB / 24564MiB | 7% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA GeForce RTX 4090 On | 00000000:82:00.0 Off | Off |
| 31% 26C P2 66W / 450W | 591MiB / 24564MiB | 3% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA GeForce RTX 4090 On | 00000000:C1:00.0 Off | Off |
| 31% 28C P2 54W / 450W | 591MiB / 24564MiB | 12% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
====================================================================================================
Training DataLoader: total number of tokens: 1000000000 across 10 files
Validation DataLoader: total number of tokens: 100000000 across 1 files
====================================================================================================
step:0/1480 val_loss:3.5264 train_time:0ms step_avg:nanms perplexity:11.5224 param_count:85,143,606
step:1/1480 train_time:575773ms step_avg:nanms
step:2/1480 train_time:585321ms step_avg:nanms
step:3/1480 train_time:586195ms step_avg:nanms
step:4/1480 train_time:587123ms step_avg:nanms
step:5/1480 train_time:588049ms step_avg:nanms
step:6/1480 train_time:588981ms step_avg:nanms
step:7/1480 train_time:589908ms step_avg:nanms
step:8/1480 train_time:590823ms step_avg:nanms
step:9/1480 train_time:591747ms step_avg:nanms
step:10/1480 train_time:592680ms step_avg:nanms
step:11/1480 train_time:936ms step_avg:nanms
step:12/1480 train_time:1865ms step_avg:nanms
step:13/1480 train_time:2786ms step_avg:928.50ms
step:14/1480 train_time:3714ms step_avg:928.40ms
step:15/1480 train_time:4646ms step_avg:929.23ms
step:16/1480 train_time:5580ms step_avg:929.97ms
step:17/1480 train_time:6512ms step_avg:930.29ms
step:18/1480 train_time:7431ms step_avg:928.86ms
step:19/1480 train_time:8357ms step_avg:928.59ms
step:20/1480 train_time:9285ms step_avg:928.51ms
step:21/1480 train_time:10213ms step_avg:928.42ms
step:22/1480 train_time:11143ms step_avg:928.55ms
step:23/1480 train_time:12069ms step_avg:928.42ms
step:24/1480 train_time:12994ms step_avg:928.12ms
step:25/1480 train_time:13919ms step_avg:927.96ms
step:25/1480 val_loss:2.9132 train_time:13963ms step_avg:930.88ms perplexity:7.5329 param_count:85,143,606
step:26/1480 train_time:14846ms step_avg:927.90ms
step:27/1480 train_time:15781ms step_avg:928.28ms
step:28/1480 train_time:16711ms step_avg:928.37ms
step:29/1480 train_time:17643ms step_avg:928.60ms
step:30/1480 train_time:18576ms step_avg:928.79ms
step:31/1480 train_time:19506ms step_avg:928.87ms
step:32/1480 train_time:20434ms step_avg:928.81ms
step:33/1480 train_time:21366ms step_avg:928.95ms
step:34/1480 train_time:22300ms step_avg:929.17ms
step:35/1480 train_time:23229ms step_avg:929.14ms
step:36/1480 train_time:24153ms step_avg:928.98ms
step:37/1480 train_time:25082ms step_avg:928.98ms
step:38/1480 train_time:26019ms step_avg:929.24ms
step:39/1480 train_time:26953ms step_avg:929.42ms
step:40/1480 train_time:27884ms step_avg:929.45ms
step:41/1480 train_time:28802ms step_avg:929.10ms
step:42/1480 train_time:29719ms step_avg:928.71ms
step:43/1480 train_time:30646ms step_avg:928.66ms
step:44/1480 train_time:31578ms step_avg:928.76ms
step:45/1480 train_time:32509ms step_avg:928.82ms
step:46/1480 train_time:33442ms step_avg:928.93ms
step:47/1480 train_time:34370ms step_avg:928.92ms
step:48/1480 train_time:35293ms step_avg:928.75ms
step:49/1480 train_time:36225ms step_avg:928.84ms
step:50/1480 train_time:37162ms step_avg:929.05ms
step:50/1480 val_loss:2.8596 train_time:37207ms step_avg:930.17ms perplexity:7.2582 param_count:85,143,606
step:51/1480 train_time:38090ms step_avg:929.03ms
step:52/1480 train_time:39017ms step_avg:928.98ms
step:53/1480 train_time:39939ms step_avg:928.82ms
step:54/1480 train_time:40861ms step_avg:928.67ms
step:55/1480 train_time:41791ms step_avg:928.68ms
step:56/1480 train_time:42722ms step_avg:928.73ms
step:57/1480 train_time:43653ms step_avg:928.78ms
step:58/1480 train_time:44587ms step_avg:928.90ms
step:59/1480 train_time:45516ms step_avg:928.90ms
step:60/1480 train_time:46443ms step_avg:928.85ms
step:61/1480 train_time:47373ms step_avg:928.89ms
step:62/1480 train_time:48302ms step_avg:928.88ms
step:63/1480 train_time:49237ms step_avg:929.00ms
step:64/1480 train_time:50172ms step_avg:929.11ms
step:65/1480 train_time:51101ms step_avg:929.11ms
step:66/1480 train_time:52034ms step_avg:929.18ms
step:67/1480 train_time:52970ms step_avg:929.30ms
step:68/1480 train_time:53892ms step_avg:929.18ms
step:69/1480 train_time:54826ms step_avg:929.26ms
step:70/1480 train_time:55755ms step_avg:929.26ms
step:71/1480 train_time:56689ms step_avg:929.33ms
step:72/1480 train_time:57621ms step_avg:929.37ms
step:73/1480 train_time:58556ms step_avg:929.45ms
step:74/1480 train_time:59482ms step_avg:929.41ms
step:75/1480 train_time:60405ms step_avg:929.31ms
step:75/1480 val_loss:2.8958 train_time:60450ms step_avg:930.00ms perplexity:7.4428 param_count:85,143,606
step:76/1480 train_time:61334ms step_avg:929.31ms
step:77/1480 train_time:62263ms step_avg:929.30ms
step:78/1480 train_time:63190ms step_avg:929.26ms
step:79/1480 train_time:64117ms step_avg:929.23ms
step:80/1480 train_time:65047ms step_avg:929.25ms
step:81/1480 train_time:65962ms step_avg:929.04ms
step:82/1480 train_time:66878ms step_avg:928.85ms
step:83/1480 train_time:67793ms step_avg:928.67ms
step:84/1480 train_time:68722ms step_avg:928.68ms
step:85/1480 train_time:69653ms step_avg:928.71ms
step:86/1480 train_time:70585ms step_avg:928.75ms
step:87/1480 train_time:71517ms step_avg:928.79ms
step:88/1480 train_time:72450ms step_avg:928.85ms
step:89/1480 train_time:73380ms step_avg:928.86ms
step:90/1480 train_time:74316ms step_avg:928.95ms
step:91/1480 train_time:75243ms step_avg:928.93ms
step:92/1480 train_time:76174ms step_avg:928.95ms
step:93/1480 train_time:77101ms step_avg:928.93ms
step:94/1480 train_time:78030ms step_avg:928.93ms
step:95/1480 train_time:78958ms step_avg:928.92ms
step:96/1480 train_time:79894ms step_avg:929.00ms
step:97/1480 train_time:80824ms step_avg:929.01ms
step:98/1480 train_time:81754ms step_avg:929.02ms
step:99/1480 train_time:82689ms step_avg:929.09ms
step:100/1480 train_time:83644ms step_avg:929.38ms
step:100/1480 val_loss:2.8498 train_time:83689ms step_avg:929.87ms perplexity:7.2092 param_count:85,143,606
step:101/1480 train_time:84597ms step_avg:929.63ms
step:102/1480 train_time:85557ms step_avg:929.97ms
step:103/1480 train_time:86519ms step_avg:930.31ms
step:104/1480 train_time:87471ms step_avg:930.54ms
step:105/1480 train_time:88429ms step_avg:930.83ms
step:106/1480 train_time:89380ms step_avg:931.04ms
step:107/1480 train_time:90334ms step_avg:931.28ms
step:108/1480 train_time:91295ms step_avg:931.59ms
step:109/1480 train_time:92247ms step_avg:931.79ms
step:110/1480 train_time:93204ms step_avg:932.04ms
step:111/1480 train_time:94166ms step_avg:932.33ms
step:112/1480 train_time:95122ms step_avg:932.57ms
step:113/1480 train_time:96082ms step_avg:932.84ms
step:114/1480 train_time:97045ms step_avg:933.12ms
step:115/1480 train_time:98002ms step_avg:933.35ms
step:116/1480 train_time:98958ms step_avg:933.56ms
step:117/1480 train_time:99908ms step_avg:933.72ms
step:118/1480 train_time:100871ms step_avg:933.99ms
step:119/1480 train_time:101836ms step_avg:934.27ms
step:120/1480 train_time:102795ms step_avg:934.50ms
step:121/1480 train_time:103742ms step_avg:934.61ms
step:122/1480 train_time:104700ms step_avg:934.82ms
step:123/1480 train_time:105656ms step_avg:935.01ms
step:124/1480 train_time:106616ms step_avg:935.23ms
step:125/1480 train_time:107568ms step_avg:935.38ms
step:125/1480 val_loss:2.8575 train_time:107614ms step_avg:935.77ms perplexity:7.2475 param_count:85,143,606
step:126/1480 train_time:108523ms step_avg:935.55ms
step:127/1480 train_time:109486ms step_avg:935.78ms
step:128/1480 train_time:110446ms step_avg:935.99ms
step:129/1480 train_time:111401ms step_avg:936.14ms
step:130/1480 train_time:112355ms step_avg:936.29ms
step:131/1480 train_time:113312ms step_avg:936.46ms
step:132/1480 train_time:114270ms step_avg:936.64ms
step:133/1480 train_time:115229ms step_avg:936.83ms
step:134/1480 train_time:116190ms step_avg:937.02ms
step:135/1480 train_time:117137ms step_avg:937.10ms
step:136/1480 train_time:118091ms step_avg:937.23ms
step:137/1480 train_time:119047ms step_avg:937.38ms
step:138/1480 train_time:119999ms step_avg:937.49ms
step:139/1480 train_time:120953ms step_avg:937.62ms
step:140/1480 train_time:121903ms step_avg:937.72ms
step:141/1480 train_time:122860ms step_avg:937.86ms
step:142/1480 train_time:123817ms step_avg:938.01ms
step:143/1480 train_time:124773ms step_avg:938.14ms
step:144/1480 train_time:125730ms step_avg:938.28ms
step:145/1480 train_time:126682ms step_avg:938.39ms
step:146/1480 train_time:127639ms step_avg:938.52ms
step:147/1480 train_time:128600ms step_avg:938.69ms
step:148/1480 train_time:129555ms step_avg:938.81ms
step:149/1480 train_time:130518ms step_avg:938.98ms
step:150/1480 train_time:131478ms step_avg:939.13ms
step:150/1480 val_loss:2.8644 train_time:131523ms step_avg:939.45ms perplexity:7.2823 param_count:85,143,606
step:151/1480 train_time:132434ms step_avg:939.25ms
step:152/1480 train_time:133394ms step_avg:939.40ms
step:153/1480 train_time:134354ms step_avg:939.54ms
step:154/1480 train_time:135300ms step_avg:939.59ms
step:155/1480 train_time:136257ms step_avg:939.70ms
step:156/1480 train_time:137219ms step_avg:939.86ms
step:157/1480 train_time:138178ms step_avg:939.99ms
step:158/1480 train_time:139135ms step_avg:940.10ms
step:159/1480 train_time:140084ms step_avg:940.16ms
step:160/1480 train_time:141046ms step_avg:940.31ms
step:161/1480 train_time:142006ms step_avg:940.43ms
step:162/1480 train_time:142969ms step_avg:940.58ms
step:163/1480 train_time:143933ms step_avg:940.74ms
step:164/1480 train_time:144892ms step_avg:940.86ms
step:165/1480 train_time:145852ms step_avg:940.98ms
step:166/1480 train_time:146806ms step_avg:941.06ms
step:167/1480 train_time:147765ms step_avg:941.18ms
step:168/1480 train_time:148724ms step_avg:941.29ms
step:169/1480 train_time:149683ms step_avg:941.40ms
step:170/1480 train_time:150641ms step_avg:941.51ms
step:171/1480 train_time:151585ms step_avg:941.52ms
step:172/1480 train_time:152532ms step_avg:941.56ms
step:173/1480 train_time:153491ms step_avg:941.67ms
step:174/1480 train_time:154446ms step_avg:941.74ms
step:175/1480 train_time:155405ms step_avg:941.85ms
step:175/1480 val_loss:2.8535 train_time:155446ms step_avg:942.10ms perplexity:7.2273 param_count:85,143,606
step:176/1480 train_time:156357ms step_avg:941.91ms
step:177/1480 train_time:157317ms step_avg:942.02ms
step:178/1480 train_time:158276ms step_avg:942.12ms
step:179/1480 train_time:159230ms step_avg:942.19ms
step:180/1480 train_time:160190ms step_avg:942.29ms
step:181/1480 train_time:161145ms step_avg:942.37ms
step:182/1480 train_time:162102ms step_avg:942.45ms
step:183/1480 train_time:163045ms step_avg:942.46ms
step:184/1480 train_time:163986ms step_avg:942.45ms
step:185/1480 train_time:164926ms step_avg:942.43ms
step:186/1480 train_time:165872ms step_avg:942.46ms
step:187/1480 train_time:166826ms step_avg:942.52ms
step:188/1480 train_time:167780ms step_avg:942.58ms
step:189/1480 train_time:168735ms step_avg:942.65ms
step:190/1480 train_time:169693ms step_avg:942.74ms
step:191/1480 train_time:170648ms step_avg:942.81ms
step:192/1480 train_time:171608ms step_avg:942.90ms
step:193/1480 train_time:172563ms step_avg:942.97ms
step:194/1480 train_time:173512ms step_avg:943.00ms
step:195/1480 train_time:174459ms step_avg:943.02ms
step:196/1480 train_time:175414ms step_avg:943.09ms
step:197/1480 train_time:176378ms step_avg:943.20ms
step:198/1480 train_time:177331ms step_avg:943.25ms
step:199/1480 train_time:178309ms step_avg:943.44ms
step:200/1480 train_time:179272ms step_avg:943.54ms
step:200/1480 val_loss:2.8428 train_time:179315ms step_avg:943.76ms perplexity:7.1742 param_count:85,143,606
step:201/1480 train_time:180245ms step_avg:943.69ms
step:202/1480 train_time:181211ms step_avg:943.81ms
step:203/1480 train_time:182181ms step_avg:943.94ms
step:204/1480 train_time:183158ms step_avg:944.11ms
step:205/1480 train_time:184131ms step_avg:944.26ms
step:206/1480 train_time:185103ms step_avg:944.40ms
step:207/1480 train_time:186076ms step_avg:944.55ms
step:208/1480 train_time:187048ms step_avg:944.69ms
step:209/1480 train_time:188008ms step_avg:944.76ms
step:210/1480 train_time:188980ms step_avg:944.90ms
step:211/1480 train_time:189958ms step_avg:945.06ms
step:212/1480 train_time:190935ms step_avg:945.22ms
step:213/1480 train_time:191904ms step_avg:945.34ms
step:214/1480 train_time:192877ms step_avg:945.47ms
step:215/1480 train_time:193853ms step_avg:945.62ms
step:216/1480 train_time:194829ms step_avg:945.77ms
step:217/1480 train_time:195801ms step_avg:945.90ms
step:218/1480 train_time:196776ms step_avg:946.04ms
step:219/1480 train_time:197749ms step_avg:946.17ms
step:220/1480 train_time:198724ms step_avg:946.30ms
step:221/1480 train_time:199684ms step_avg:946.37ms
step:222/1480 train_time:200653ms step_avg:946.48ms
step:223/1480 train_time:201625ms step_avg:946.59ms
step:224/1480 train_time:202597ms step_avg:946.71ms
step:225/1480 train_time:203570ms step_avg:946.84ms
step:225/1480 val_loss:2.8480 train_time:203615ms step_avg:947.05ms perplexity:7.2000 param_count:85,143,606
step:226/1480 train_time:204543ms step_avg:946.96ms
step:227/1480 train_time:205521ms step_avg:947.10ms
step:228/1480 train_time:206497ms step_avg:947.23ms
step:229/1480 train_time:207470ms step_avg:947.35ms
step:230/1480 train_time:208427ms step_avg:947.39ms
step:231/1480 train_time:209397ms step_avg:947.50ms
step:232/1480 train_time:210371ms step_avg:947.62ms
step:233/1480 train_time:211345ms step_avg:947.73ms
step:234/1480 train_time:212316ms step_avg:947.84ms
step:235/1480 train_time:213287ms step_avg:947.94ms
step:236/1480 train_time:214257ms step_avg:948.04ms
step:237/1480 train_time:215231ms step_avg:948.15ms
step:238/1480 train_time:216207ms step_avg:948.27ms
step:239/1480 train_time:217182ms step_avg:948.39ms
step:240/1480 train_time:218158ms step_avg:948.51ms
step:241/1480 train_time:219136ms step_avg:948.64ms
step:242/1480 train_time:220100ms step_avg:948.71ms
step:243/1480 train_time:221076ms step_avg:948.83ms
step:244/1480 train_time:222056ms step_avg:948.96ms
step:245/1480 train_time:223033ms step_avg:949.08ms
step:246/1480 train_time:224005ms step_avg:949.18ms
step:247/1480 train_time:224980ms step_avg:949.28ms
step:248/1480 train_time:225955ms step_avg:949.39ms
step:249/1480 train_time:226920ms step_avg:949.46ms
step:250/1480 train_time:227889ms step_avg:949.54ms
step:250/1480 val_loss:2.8405 train_time:227934ms step_avg:949.72ms perplexity:7.1625 param_count:85,143,606
step:251/1480 train_time:228847ms step_avg:949.57ms
step:252/1480 train_time:229810ms step_avg:949.63ms
step:253/1480 train_time:230786ms step_avg:949.74ms
step:254/1480 train_time:231762ms step_avg:949.85ms
step:255/1480 train_time:232734ms step_avg:949.94ms
step:256/1480 train_time:233705ms step_avg:950.02ms
step:257/1480 train_time:234682ms step_avg:950.13ms
step:258/1480 train_time:235658ms step_avg:950.24ms
step:259/1480 train_time:236632ms step_avg:950.33ms
step:260/1480 train_time:237611ms step_avg:950.44ms
step:261/1480 train_time:238587ms step_avg:950.55ms
step:262/1480 train_time:239562ms step_avg:950.64ms
step:263/1480 train_time:240527ms step_avg:950.70ms
step:264/1480 train_time:241494ms step_avg:950.77ms
step:265/1480 train_time:242464ms step_avg:950.84ms
step:266/1480 train_time:243436ms step_avg:950.92ms
step:267/1480 train_time:244406ms step_avg:951.00ms
step:268/1480 train_time:245379ms step_avg:951.08ms
step:269/1480 train_time:246343ms step_avg:951.13ms
step:270/1480 train_time:247317ms step_avg:951.22ms
step:271/1480 train_time:248290ms step_avg:951.30ms
step:272/1480 train_time:249266ms step_avg:951.40ms
step:273/1480 train_time:250239ms step_avg:951.48ms
step:274/1480 train_time:251208ms step_avg:951.54ms
step:275/1480 train_time:252169ms step_avg:951.58ms
step:275/1480 val_loss:2.8512 train_time:252209ms step_avg:951.73ms perplexity:7.2158 param_count:85,143,606
step:276/1480 train_time:253136ms step_avg:951.64ms
step:277/1480 train_time:254097ms step_avg:951.68ms
step:278/1480 train_time:255068ms step_avg:951.75ms
step:279/1480 train_time:256046ms step_avg:951.85ms
step:280/1480 train_time:257019ms step_avg:951.92ms
step:281/1480 train_time:257990ms step_avg:951.99ms
step:282/1480 train_time:258962ms step_avg:952.07ms
step:283/1480 train_time:259935ms step_avg:952.14ms
step:284/1480 train_time:260910ms step_avg:952.23ms
step:285/1480 train_time:261881ms step_avg:952.30ms
step:286/1480 train_time:262858ms step_avg:952.39ms
step:287/1480 train_time:263835ms step_avg:952.47ms
step:288/1480 train_time:264806ms step_avg:952.54ms
step:289/1480 train_time:265770ms step_avg:952.58ms
step:290/1480 train_time:266737ms step_avg:952.63ms
step:291/1480 train_time:267706ms step_avg:952.69ms
step:292/1480 train_time:268682ms step_avg:952.77ms
step:293/1480 train_time:269654ms step_avg:952.84ms
step:294/1480 train_time:270617ms step_avg:952.88ms
step:295/1480 train_time:271584ms step_avg:952.93ms
step:296/1480 train_time:272556ms step_avg:952.99ms
step:297/1480 train_time:273530ms step_avg:953.07ms
step:298/1480 train_time:274511ms step_avg:953.16ms
step:299/1480 train_time:275496ms step_avg:953.27ms
step:300/1480 train_time:276479ms step_avg:953.38ms
step:300/1480 val_loss:2.8196 train_time:276525ms step_avg:953.53ms perplexity:7.0598 param_count:85,143,606
step:301/1480 train_time:277460ms step_avg:953.47ms
step:302/1480 train_time:278440ms step_avg:953.56ms
step:303/1480 train_time:279425ms step_avg:953.67ms
step:304/1480 train_time:280406ms step_avg:953.76ms
step:305/1480 train_time:281367ms step_avg:953.79ms
step:306/1480 train_time:282342ms step_avg:953.86ms
step:307/1480 train_time:283326ms step_avg:953.96ms
step:308/1480 train_time:284310ms step_avg:954.06ms
step:309/1480 train_time:285288ms step_avg:954.14ms
step:310/1480 train_time:286276ms step_avg:954.25ms
step:311/1480 train_time:287260ms step_avg:954.35ms
step:312/1480 train_time:288245ms step_avg:954.45ms
step:313/1480 train_time:289219ms step_avg:954.52ms
step:314/1480 train_time:290204ms step_avg:954.62ms
step:315/1480 train_time:291186ms step_avg:954.71ms
step:316/1480 train_time:292173ms step_avg:954.81ms
step:317/1480 train_time:293142ms step_avg:954.86ms
step:318/1480 train_time:294118ms step_avg:954.93ms
step:319/1480 train_time:295101ms step_avg:955.02ms
step:320/1480 train_time:296079ms step_avg:955.09ms
step:321/1480 train_time:297061ms step_avg:955.18ms
step:322/1480 train_time:298043ms step_avg:955.27ms
step:323/1480 train_time:299027ms step_avg:955.36ms
step:324/1480 train_time:300001ms step_avg:955.42ms
step:325/1480 train_time:300977ms step_avg:955.48ms
step:325/1480 val_loss:2.8234 train_time:301022ms step_avg:955.62ms perplexity:7.0784 param_count:85,143,606
step:326/1480 train_time:301944ms step_avg:955.52ms
step:327/1480 train_time:302918ms step_avg:955.58ms
step:328/1480 train_time:303903ms step_avg:955.67ms
step:329/1480 train_time:304886ms step_avg:955.76ms
step:330/1480 train_time:305867ms step_avg:955.83ms
step:331/1480 train_time:306852ms step_avg:955.93ms
step:332/1480 train_time:307830ms step_avg:955.99ms
step:333/1480 train_time:308801ms step_avg:956.04ms
step:334/1480 train_time:309775ms step_avg:956.10ms
step:335/1480 train_time:310754ms step_avg:956.16ms
step:336/1480 train_time:311732ms step_avg:956.23ms
step:337/1480 train_time:312715ms step_avg:956.31ms
step:338/1480 train_time:313703ms step_avg:956.41ms
step:339/1480 train_time:314687ms step_avg:956.50ms
step:340/1480 train_time:315666ms step_avg:956.56ms
step:341/1480 train_time:316643ms step_avg:956.63ms
step:342/1480 train_time:317630ms step_avg:956.72ms
step:343/1480 train_time:318604ms step_avg:956.77ms
step:344/1480 train_time:319590ms step_avg:956.86ms
step:345/1480 train_time:320556ms step_avg:956.88ms
step:346/1480 train_time:321537ms step_avg:956.95ms
step:347/1480 train_time:322522ms step_avg:957.04ms
step:348/1480 train_time:323500ms step_avg:957.10ms
step:349/1480 train_time:324482ms step_avg:957.17ms
step:350/1480 train_time:325468ms step_avg:957.26ms
step:350/1480 val_loss:2.8117 train_time:325509ms step_avg:957.38ms perplexity:7.0209 param_count:85,143,606
step:351/1480 train_time:326445ms step_avg:957.32ms
step:352/1480 train_time:327424ms step_avg:957.38ms
step:353/1480 train_time:328408ms step_avg:957.46ms
step:354/1480 train_time:329395ms step_avg:957.54ms
step:355/1480 train_time:330373ms step_avg:957.60ms
step:356/1480 train_time:331352ms step_avg:957.66ms
step:357/1480 train_time:332331ms step_avg:957.73ms
step:358/1480 train_time:333312ms step_avg:957.79ms
step:359/1480 train_time:334289ms step_avg:957.85ms
step:360/1480 train_time:335274ms step_avg:957.93ms
step:361/1480 train_time:336256ms step_avg:958.00ms
step:362/1480 train_time:337237ms step_avg:958.06ms
step:363/1480 train_time:338221ms step_avg:958.13ms
step:364/1480 train_time:339203ms step_avg:958.20ms
step:365/1480 train_time:340175ms step_avg:958.24ms
step:366/1480 train_time:341153ms step_avg:958.30ms
step:367/1480 train_time:342135ms step_avg:958.36ms
step:368/1480 train_time:343117ms step_avg:958.43ms
step:369/1480 train_time:344099ms step_avg:958.49ms
step:370/1480 train_time:345078ms step_avg:958.55ms
step:371/1480 train_time:346061ms step_avg:958.62ms
step:372/1480 train_time:347044ms step_avg:958.68ms
step:373/1480 train_time:348022ms step_avg:958.74ms
step:374/1480 train_time:349009ms step_avg:958.82ms
step:375/1480 train_time:349995ms step_avg:958.89ms
step:375/1480 val_loss:2.8197 train_time:350040ms step_avg:959.01ms perplexity:7.0604 param_count:85,143,606
step:376/1480 train_time:350981ms step_avg:958.96ms
step:377/1480 train_time:351955ms step_avg:959.01ms
step:378/1480 train_time:352938ms step_avg:959.07ms
step:379/1480 train_time:353912ms step_avg:959.11ms
step:380/1480 train_time:354894ms step_avg:959.17ms
step:381/1480 train_time:355866ms step_avg:959.21ms
step:382/1480 train_time:356846ms step_avg:959.26ms
step:383/1480 train_time:357829ms step_avg:959.33ms
step:384/1480 train_time:358809ms step_avg:959.38ms
step:385/1480 train_time:359792ms step_avg:959.45ms
step:386/1480 train_time:360775ms step_avg:959.51ms
step:387/1480 train_time:361751ms step_avg:959.55ms
step:388/1480 train_time:362732ms step_avg:959.61ms
step:389/1480 train_time:363709ms step_avg:959.65ms
step:390/1480 train_time:364683ms step_avg:959.69ms
step:391/1480 train_time:365665ms step_avg:959.75ms
step:392/1480 train_time:366647ms step_avg:959.81ms
step:393/1480 train_time:367623ms step_avg:959.85ms
step:394/1480 train_time:368600ms step_avg:959.90ms
step:395/1480 train_time:369585ms step_avg:959.96ms
step:396/1480 train_time:370571ms step_avg:960.03ms
step:397/1480 train_time:371557ms step_avg:960.10ms
step:398/1480 train_time:372547ms step_avg:960.17ms
step:399/1480 train_time:373539ms step_avg:960.25ms
step:400/1480 train_time:374512ms step_avg:960.29ms
step:400/1480 val_loss:2.8241 train_time:374552ms step_avg:960.39ms perplexity:7.0815 param_count:85,143,606
step:401/1480 train_time:375493ms step_avg:960.34ms
step:402/1480 train_time:376474ms step_avg:960.39ms
step:403/1480 train_time:377462ms step_avg:960.46ms
step:404/1480 train_time:378442ms step_avg:960.51ms
step:405/1480 train_time:379425ms step_avg:960.57ms
step:406/1480 train_time:380414ms step_avg:960.64ms
step:407/1480 train_time:381400ms step_avg:960.71ms
step:408/1480 train_time:382387ms step_avg:960.77ms
step:409/1480 train_time:383365ms step_avg:960.82ms
step:410/1480 train_time:384348ms step_avg:960.87ms
step:411/1480 train_time:385334ms step_avg:960.93ms
step:412/1480 train_time:386326ms step_avg:961.01ms
step:413/1480 train_time:387314ms step_avg:961.08ms
step:414/1480 train_time:388295ms step_avg:961.13ms
step:415/1480 train_time:389279ms step_avg:961.18ms
step:416/1480 train_time:390259ms step_avg:961.23ms
step:417/1480 train_time:391242ms step_avg:961.28ms
step:418/1480 train_time:392236ms step_avg:961.36ms
step:419/1480 train_time:393221ms step_avg:961.42ms
step:420/1480 train_time:394207ms step_avg:961.48ms
step:421/1480 train_time:395199ms step_avg:961.55ms
step:422/1480 train_time:396190ms step_avg:961.63ms
step:423/1480 train_time:397165ms step_avg:961.66ms
step:424/1480 train_time:398150ms step_avg:961.72ms
step:425/1480 train_time:399136ms step_avg:961.77ms
step:425/1480 val_loss:2.8119 train_time:399180ms step_avg:961.88ms perplexity:7.0219 param_count:85,143,606
step:426/1480 train_time:400115ms step_avg:961.82ms
step:427/1480 train_time:401100ms step_avg:961.87ms
step:428/1480 train_time:402082ms step_avg:961.92ms
step:429/1480 train_time:403072ms step_avg:961.99ms
step:430/1480 train_time:404066ms step_avg:962.06ms
step:431/1480 train_time:405050ms step_avg:962.11ms
step:432/1480 train_time:406030ms step_avg:962.16ms
step:433/1480 train_time:407016ms step_avg:962.21ms
step:434/1480 train_time:408003ms step_avg:962.27ms
step:435/1480 train_time:408991ms step_avg:962.33ms
step:436/1480 train_time:409981ms step_avg:962.40ms
step:437/1480 train_time:410968ms step_avg:962.45ms
step:438/1480 train_time:411950ms step_avg:962.50ms
step:439/1480 train_time:412925ms step_avg:962.53ms
step:440/1480 train_time:413905ms step_avg:962.57ms
step:441/1480 train_time:414897ms step_avg:962.64ms
step:442/1480 train_time:415887ms step_avg:962.70ms
step:443/1480 train_time:416880ms step_avg:962.77ms
step:444/1480 train_time:417862ms step_avg:962.81ms
step:445/1480 train_time:418844ms step_avg:962.86ms
step:446/1480 train_time:419831ms step_avg:962.91ms
step:447/1480 train_time:420801ms step_avg:962.93ms
step:448/1480 train_time:421782ms step_avg:962.97ms
step:449/1480 train_time:422769ms step_avg:963.03ms
step:450/1480 train_time:423755ms step_avg:963.08ms
step:450/1480 val_loss:2.8524 train_time:423800ms step_avg:963.18ms perplexity:7.2222 param_count:85,143,606
step:451/1480 train_time:424742ms step_avg:963.13ms
step:452/1480 train_time:425729ms step_avg:963.19ms
step:453/1480 train_time:426721ms step_avg:963.25ms
step:454/1480 train_time:427720ms step_avg:963.33ms
step:455/1480 train_time:428702ms step_avg:963.38ms
step:456/1480 train_time:429691ms step_avg:963.43ms
step:457/1480 train_time:430675ms step_avg:963.48ms
step:458/1480 train_time:431663ms step_avg:963.53ms
step:459/1480 train_time:432650ms step_avg:963.59ms
step:460/1480 train_time:433641ms step_avg:963.65ms
step:461/1480 train_time:434632ms step_avg:963.71ms
step:462/1480 train_time:435606ms step_avg:963.73ms
step:463/1480 train_time:436595ms step_avg:963.78ms
step:464/1480 train_time:437578ms step_avg:963.83ms
step:465/1480 train_time:438566ms step_avg:963.88ms
step:466/1480 train_time:439544ms step_avg:963.91ms
step:467/1480 train_time:440521ms step_avg:963.94ms
step:468/1480 train_time:441506ms step_avg:963.99ms
step:469/1480 train_time:442482ms step_avg:964.01ms
step:470/1480 train_time:443469ms step_avg:964.06ms
step:471/1480 train_time:444454ms step_avg:964.11ms
step:472/1480 train_time:445442ms step_avg:964.16ms
step:473/1480 train_time:446432ms step_avg:964.22ms
step:474/1480 train_time:447417ms step_avg:964.26ms
step:475/1480 train_time:448402ms step_avg:964.31ms
step:475/1480 val_loss:2.8162 train_time:448448ms step_avg:964.40ms perplexity:7.0429 param_count:85,143,606
step:476/1480 train_time:449390ms step_avg:964.36ms
step:477/1480 train_time:450380ms step_avg:964.41ms
step:478/1480 train_time:451358ms step_avg:964.44ms
step:479/1480 train_time:452346ms step_avg:964.49ms
step:480/1480 train_time:453325ms step_avg:964.52ms
step:481/1480 train_time:454311ms step_avg:964.57ms
step:482/1480 train_time:455302ms step_avg:964.62ms
step:483/1480 train_time:456293ms step_avg:964.68ms
step:484/1480 train_time:457277ms step_avg:964.72ms
step:485/1480 train_time:458260ms step_avg:964.76ms
step:486/1480 train_time:459248ms step_avg:964.81ms
step:487/1480 train_time:460224ms step_avg:964.83ms
step:488/1480 train_time:461208ms step_avg:964.87ms
step:489/1480 train_time:462196ms step_avg:964.92ms
step:490/1480 train_time:463188ms step_avg:964.98ms
step:491/1480 train_time:464181ms step_avg:965.03ms
step:492/1480 train_time:465156ms step_avg:965.05ms
step:493/1480 train_time:466141ms step_avg:965.10ms
step:494/1480 train_time:467132ms step_avg:965.15ms
step:495/1480 train_time:468123ms step_avg:965.20ms
step:496/1480 train_time:469115ms step_avg:965.26ms
step:497/1480 train_time:470112ms step_avg:965.32ms
step:498/1480 train_time:471103ms step_avg:965.38ms
step:499/1480 train_time:472087ms step_avg:965.41ms
step:500/1480 train_time:473075ms step_avg:965.46ms
step:500/1480 val_loss:2.8146 train_time:473119ms step_avg:965.55ms perplexity:7.0350 param_count:85,143,606
step:501/1480 train_time:474053ms step_avg:965.48ms
step:502/1480 train_time:475042ms step_avg:965.53ms
step:503/1480 train_time:476031ms step_avg:965.58ms
step:504/1480 train_time:477019ms step_avg:965.62ms
step:505/1480 train_time:477993ms step_avg:965.64ms
step:506/1480 train_time:478978ms step_avg:965.68ms
step:507/1480 train_time:479971ms step_avg:965.74ms
step:508/1480 train_time:480971ms step_avg:965.81ms
step:509/1480 train_time:481966ms step_avg:965.86ms
step:510/1480 train_time:482950ms step_avg:965.90ms
step:511/1480 train_time:483944ms step_avg:965.96ms
step:512/1480 train_time:484921ms step_avg:965.98ms
step:513/1480 train_time:485912ms step_avg:966.03ms
step:514/1480 train_time:486899ms step_avg:966.07ms
step:515/1480 train_time:487891ms step_avg:966.12ms
step:516/1480 train_time:488890ms step_avg:966.19ms
step:517/1480 train_time:489886ms step_avg:966.24ms
step:518/1480 train_time:490878ms step_avg:966.29ms
step:519/1480 train_time:491855ms step_avg:966.32ms
step:520/1480 train_time:492840ms step_avg:966.35ms
step:521/1480 train_time:493831ms step_avg:966.40ms
step:522/1480 train_time:494821ms step_avg:966.45ms
step:523/1480 train_time:495809ms step_avg:966.49ms
step:524/1480 train_time:496790ms step_avg:966.52ms
step:525/1480 train_time:497780ms step_avg:966.56ms
step:525/1480 val_loss:2.8129 train_time:497826ms step_avg:966.65ms perplexity:7.0270 param_count:85,143,606
step:526/1480 train_time:498767ms step_avg:966.60ms
step:527/1480 train_time:499760ms step_avg:966.65ms
step:528/1480 train_time:500755ms step_avg:966.71ms
step:529/1480 train_time:501744ms step_avg:966.75ms
step:530/1480 train_time:502729ms step_avg:966.79ms
step:531/1480 train_time:503725ms step_avg:966.84ms
step:532/1480 train_time:504714ms step_avg:966.88ms
step:533/1480 train_time:505703ms step_avg:966.93ms
step:534/1480 train_time:506700ms step_avg:966.98ms
step:535/1480 train_time:507682ms step_avg:967.01ms
step:536/1480 train_time:508672ms step_avg:967.06ms
step:537/1480 train_time:509661ms step_avg:967.10ms
step:538/1480 train_time:510643ms step_avg:967.13ms
step:539/1480 train_time:511630ms step_avg:967.16ms
step:540/1480 train_time:512632ms step_avg:967.23ms
step:541/1480 train_time:513619ms step_avg:967.27ms
step:542/1480 train_time:514604ms step_avg:967.30ms
step:543/1480 train_time:515603ms step_avg:967.36ms
step:544/1480 train_time:516598ms step_avg:967.41ms
step:545/1480 train_time:517589ms step_avg:967.46ms
step:546/1480 train_time:518585ms step_avg:967.51ms
step:547/1480 train_time:519574ms step_avg:967.55ms
step:548/1480 train_time:520559ms step_avg:967.58ms
step:549/1480 train_time:521537ms step_avg:967.60ms
step:550/1480 train_time:522517ms step_avg:967.62ms
step:550/1480 val_loss:2.8151 train_time:522557ms step_avg:967.70ms perplexity:7.0375 param_count:85,143,606
step:551/1480 train_time:523501ms step_avg:967.65ms
step:552/1480 train_time:524490ms step_avg:967.69ms
step:553/1480 train_time:525467ms step_avg:967.71ms
step:554/1480 train_time:526452ms step_avg:967.74ms
step:555/1480 train_time:527439ms step_avg:967.78ms
step:556/1480 train_time:528427ms step_avg:967.82ms
step:557/1480 train_time:529421ms step_avg:967.86ms
step:558/1480 train_time:530406ms step_avg:967.89ms
step:559/1480 train_time:531401ms step_avg:967.94ms
step:560/1480 train_time:532387ms step_avg:967.98ms
step:561/1480 train_time:533376ms step_avg:968.01ms
step:562/1480 train_time:534365ms step_avg:968.05ms
step:563/1480 train_time:535356ms step_avg:968.09ms
step:564/1480 train_time:536352ms step_avg:968.14ms
step:565/1480 train_time:537344ms step_avg:968.19ms
step:566/1480 train_time:538345ms step_avg:968.25ms
step:567/1480 train_time:539329ms step_avg:968.28ms
step:568/1480 train_time:540317ms step_avg:968.31ms
step:569/1480 train_time:541304ms step_avg:968.34ms
step:570/1480 train_time:542299ms step_avg:968.39ms
step:571/1480 train_time:543303ms step_avg:968.46ms
step:572/1480 train_time:544287ms step_avg:968.48ms
step:573/1480 train_time:545278ms step_avg:968.52ms
step:574/1480 train_time:546263ms step_avg:968.55ms
step:575/1480 train_time:547249ms step_avg:968.58ms
step:575/1480 val_loss:2.8098 train_time:547293ms step_avg:968.66ms perplexity:7.0117 param_count:85,143,606
step:576/1480 train_time:548227ms step_avg:968.60ms
step:577/1480 train_time:549208ms step_avg:968.62ms
step:578/1480 train_time:550206ms step_avg:968.67ms
step:579/1480 train_time:551198ms step_avg:968.71ms
step:580/1480 train_time:552186ms step_avg:968.75ms
step:581/1480 train_time:553172ms step_avg:968.78ms
step:582/1480 train_time:554161ms step_avg:968.81ms
step:583/1480 train_time:555146ms step_avg:968.84ms
step:584/1480 train_time:556134ms step_avg:968.88ms
step:585/1480 train_time:557128ms step_avg:968.92ms
step:586/1480 train_time:558115ms step_avg:968.95ms
step:587/1480 train_time:559108ms step_avg:968.99ms
step:588/1480 train_time:560104ms step_avg:969.04ms
step:589/1480 train_time:561101ms step_avg:969.09ms
step:590/1480 train_time:562088ms step_avg:969.12ms
step:591/1480 train_time:563075ms step_avg:969.15ms
step:592/1480 train_time:564069ms step_avg:969.19ms
step:593/1480 train_time:565064ms step_avg:969.24ms
step:594/1480 train_time:566058ms step_avg:969.28ms
step:595/1480 train_time:567039ms step_avg:969.30ms
step:596/1480 train_time:568033ms step_avg:969.34ms
step:597/1480 train_time:569025ms step_avg:969.38ms
step:598/1480 train_time:570012ms step_avg:969.41ms
step:599/1480 train_time:571010ms step_avg:969.46ms
step:600/1480 train_time:572004ms step_avg:969.50ms
step:600/1480 val_loss:2.8005 train_time:572045ms step_avg:969.57ms perplexity:6.9670 param_count:85,143,606
step:601/1480 train_time:572987ms step_avg:969.52ms
step:602/1480 train_time:573972ms step_avg:969.55ms
step:603/1480 train_time:574979ms step_avg:969.61ms
step:604/1480 train_time:575982ms step_avg:969.67ms
step:605/1480 train_time:576973ms step_avg:969.70ms
step:606/1480 train_time:577959ms step_avg:969.73ms
step:607/1480 train_time:578958ms step_avg:969.78ms
step:608/1480 train_time:579956ms step_avg:969.83ms
step:609/1480 train_time:580948ms step_avg:969.86ms
step:610/1480 train_time:581949ms step_avg:969.91ms
step:611/1480 train_time:582940ms step_avg:969.95ms
step:612/1480 train_time:583937ms step_avg:970.00ms
step:613/1480 train_time:584927ms step_avg:970.03ms
step:614/1480 train_time:585923ms step_avg:970.07ms
step:615/1480 train_time:586913ms step_avg:970.10ms
step:616/1480 train_time:587905ms step_avg:970.14ms
step:617/1480 train_time:588890ms step_avg:970.16ms
step:618/1480 train_time:589875ms step_avg:970.19ms
step:619/1480 train_time:590861ms step_avg:970.21ms
step:620/1480 train_time:591842ms step_avg:970.23ms
step:621/1480 train_time:592836ms step_avg:970.27ms
step:622/1480 train_time:593821ms step_avg:970.30ms
step:623/1480 train_time:594808ms step_avg:970.32ms
step:624/1480 train_time:595801ms step_avg:970.36ms
step:625/1480 train_time:596794ms step_avg:970.40ms
step:625/1480 val_loss:2.8111 train_time:596834ms step_avg:970.46ms perplexity:7.0181 param_count:85,143,606
step:626/1480 train_time:597777ms step_avg:970.42ms
step:627/1480 train_time:598770ms step_avg:970.45ms
step:628/1480 train_time:599758ms step_avg:970.48ms
step:629/1480 train_time:600746ms step_avg:970.51ms
step:630/1480 train_time:601733ms step_avg:970.54ms
step:631/1480 train_time:602718ms step_avg:970.56ms
step:632/1480 train_time:603711ms step_avg:970.60ms
step:633/1480 train_time:604700ms step_avg:970.63ms
step:634/1480 train_time:605691ms step_avg:970.66ms
step:635/1480 train_time:606686ms step_avg:970.70ms
step:636/1480 train_time:607663ms step_avg:970.71ms
step:637/1480 train_time:608660ms step_avg:970.75ms
step:638/1480 train_time:609642ms step_avg:970.77ms
step:639/1480 train_time:610637ms step_avg:970.81ms
step:640/1480 train_time:611637ms step_avg:970.85ms
step:641/1480 train_time:612629ms step_avg:970.89ms
step:642/1480 train_time:613619ms step_avg:970.92ms
step:643/1480 train_time:614612ms step_avg:970.95ms
step:644/1480 train_time:615606ms step_avg:970.99ms
step:645/1480 train_time:616584ms step_avg:971.00ms
step:646/1480 train_time:617569ms step_avg:971.02ms
step:647/1480 train_time:618563ms step_avg:971.06ms
step:648/1480 train_time:619552ms step_avg:971.08ms
step:649/1480 train_time:620552ms step_avg:971.13ms
step:650/1480 train_time:621549ms step_avg:971.17ms
step:650/1480 val_loss:2.7938 train_time:621594ms step_avg:971.24ms perplexity:6.9347 param_count:85,143,606
step:651/1480 train_time:622535ms step_avg:971.19ms
step:652/1480 train_time:623541ms step_avg:971.25ms
step:653/1480 train_time:624539ms step_avg:971.29ms
step:654/1480 train_time:625531ms step_avg:971.32ms
step:655/1480 train_time:626528ms step_avg:971.36ms
step:656/1480 train_time:627512ms step_avg:971.38ms
step:657/1480 train_time:628500ms step_avg:971.41ms
step:658/1480 train_time:629503ms step_avg:971.45ms
step:659/1480 train_time:630498ms step_avg:971.49ms
step:660/1480 train_time:631487ms step_avg:971.52ms
step:661/1480 train_time:632478ms step_avg:971.55ms
step:662/1480 train_time:633481ms step_avg:971.60ms
step:663/1480 train_time:634493ms step_avg:971.66ms
step:664/1480 train_time:635491ms step_avg:971.70ms
step:665/1480 train_time:636475ms step_avg:971.72ms
step:666/1480 train_time:637469ms step_avg:971.75ms
step:667/1480 train_time:638459ms step_avg:971.78ms
step:668/1480 train_time:639453ms step_avg:971.81ms
step:669/1480 train_time:640445ms step_avg:971.84ms
step:670/1480 train_time:641430ms step_avg:971.86ms
step:671/1480 train_time:642436ms step_avg:971.92ms
step:672/1480 train_time:643434ms step_avg:971.95ms
step:673/1480 train_time:644429ms step_avg:971.99ms
step:674/1480 train_time:645420ms step_avg:972.02ms
step:675/1480 train_time:646415ms step_avg:972.05ms
step:675/1480 val_loss:2.7985 train_time:646461ms step_avg:972.12ms perplexity:6.9574 param_count:85,143,606
step:676/1480 train_time:647399ms step_avg:972.07ms
step:677/1480 train_time:648391ms step_avg:972.10ms
step:678/1480 train_time:649391ms step_avg:972.14ms
step:679/1480 train_time:650380ms step_avg:972.17ms
step:680/1480 train_time:651375ms step_avg:972.20ms
step:681/1480 train_time:652352ms step_avg:972.21ms
step:682/1480 train_time:653345ms step_avg:972.24ms
step:683/1480 train_time:654342ms step_avg:972.28ms
step:684/1480 train_time:655336ms step_avg:972.31ms
step:685/1480 train_time:656328ms step_avg:972.34ms
step:686/1480 train_time:657317ms step_avg:972.36ms
step:687/1480 train_time:658311ms step_avg:972.39ms
step:688/1480 train_time:659292ms step_avg:972.41ms
step:689/1480 train_time:660284ms step_avg:972.44ms
step:690/1480 train_time:661276ms step_avg:972.46ms
step:691/1480 train_time:662274ms step_avg:972.50ms
step:692/1480 train_time:663261ms step_avg:972.52ms
step:693/1480 train_time:664258ms step_avg:972.56ms
step:694/1480 train_time:665259ms step_avg:972.60ms
step:695/1480 train_time:666248ms step_avg:972.62ms
step:696/1480 train_time:667248ms step_avg:972.67ms
step:697/1480 train_time:668237ms step_avg:972.69ms
step:698/1480 train_time:669234ms step_avg:972.72ms
step:699/1480 train_time:670232ms step_avg:972.76ms
step:700/1480 train_time:671223ms step_avg:972.79ms
step:700/1480 val_loss:2.7909 train_time:671267ms step_avg:972.85ms perplexity:6.9205 param_count:85,143,606
step:701/1480 train_time:672223ms step_avg:972.83ms
step:702/1480 train_time:673213ms step_avg:972.85ms
step:703/1480 train_time:674212ms step_avg:972.89ms
step:704/1480 train_time:675201ms step_avg:972.91ms
step:705/1480 train_time:676202ms step_avg:972.95ms
step:706/1480 train_time:677200ms step_avg:972.99ms
step:707/1480 train_time:678187ms step_avg:973.01ms
step:708/1480 train_time:679182ms step_avg:973.04ms
step:709/1480 train_time:680171ms step_avg:973.06ms
step:710/1480 train_time:681165ms step_avg:973.09ms
step:711/1480 train_time:682161ms step_avg:973.13ms
step:712/1480 train_time:683142ms step_avg:973.14ms
step:713/1480 train_time:684133ms step_avg:973.16ms
step:714/1480 train_time:685135ms step_avg:973.20ms
step:715/1480 train_time:686123ms step_avg:973.22ms
step:716/1480 train_time:687123ms step_avg:973.26ms
step:717/1480 train_time:688119ms step_avg:973.29ms
step:718/1480 train_time:689118ms step_avg:973.33ms
step:719/1480 train_time:690102ms step_avg:973.35ms
step:720/1480 train_time:691096ms step_avg:973.37ms
step:721/1480 train_time:692096ms step_avg:973.41ms
step:722/1480 train_time:693094ms step_avg:973.45ms
step:723/1480 train_time:694082ms step_avg:973.47ms
step:724/1480 train_time:695082ms step_avg:973.50ms
step:725/1480 train_time:696074ms step_avg:973.53ms
step:725/1480 val_loss:2.8093 train_time:696118ms step_avg:973.59ms perplexity:7.0095 param_count:85,143,606
step:726/1480 train_time:697101ms step_avg:973.60ms
step:727/1480 train_time:698101ms step_avg:973.64ms
step:728/1480 train_time:699088ms step_avg:973.66ms
step:729/1480 train_time:700082ms step_avg:973.69ms
step:730/1480 train_time:701075ms step_avg:973.72ms
step:731/1480 train_time:702072ms step_avg:973.75ms
step:732/1480 train_time:703062ms step_avg:973.77ms
step:733/1480 train_time:704064ms step_avg:973.81ms
step:734/1480 train_time:705069ms step_avg:973.85ms
step:735/1480 train_time:706063ms step_avg:973.88ms
step:736/1480 train_time:707059ms step_avg:973.91ms
step:737/1480 train_time:708052ms step_avg:973.94ms
step:738/1480 train_time:709052ms step_avg:973.97ms
step:739/1480 train_time:710034ms step_avg:973.98ms
step:740/1480 train_time:711025ms step_avg:974.01ms
step:741/1480 train_time:712021ms step_avg:974.04ms
step:742/1480 train_time:713013ms step_avg:974.06ms
step:743/1480 train_time:714000ms step_avg:974.08ms
step:744/1480 train_time:715003ms step_avg:974.12ms
step:745/1480 train_time:715996ms step_avg:974.14ms
step:746/1480 train_time:716990ms step_avg:974.17ms
step:747/1480 train_time:717976ms step_avg:974.19ms
step:748/1480 train_time:718972ms step_avg:974.22ms
step:749/1480 train_time:719962ms step_avg:974.24ms
step:750/1480 train_time:720964ms step_avg:974.28ms
step:750/1480 val_loss:2.7914 train_time:721007ms step_avg:974.33ms perplexity:6.9228 param_count:85,143,606
step:751/1480 train_time:721957ms step_avg:974.30ms
step:752/1480 train_time:722950ms step_avg:974.33ms
step:753/1480 train_time:723948ms step_avg:974.36ms
step:754/1480 train_time:724944ms step_avg:974.39ms
step:755/1480 train_time:725931ms step_avg:974.40ms
step:756/1480 train_time:726927ms step_avg:974.43ms
step:757/1480 train_time:727917ms step_avg:974.45ms
step:758/1480 train_time:728905ms step_avg:974.47ms
step:759/1480 train_time:729896ms step_avg:974.49ms
step:760/1480 train_time:730902ms step_avg:974.54ms
step:761/1480 train_time:731904ms step_avg:974.57ms
step:762/1480 train_time:732901ms step_avg:974.60ms
step:763/1480 train_time:733895ms step_avg:974.63ms
step:764/1480 train_time:734874ms step_avg:974.63ms
step:765/1480 train_time:735883ms step_avg:974.68ms
step:766/1480 train_time:736886ms step_avg:974.72ms
step:767/1480 train_time:737884ms step_avg:974.75ms
step:768/1480 train_time:738874ms step_avg:974.77ms
step:769/1480 train_time:739860ms step_avg:974.78ms
step:770/1480 train_time:740853ms step_avg:974.81ms
step:771/1480 train_time:741830ms step_avg:974.81ms
step:772/1480 train_time:742810ms step_avg:974.82ms
step:773/1480 train_time:743814ms step_avg:974.85ms
step:774/1480 train_time:744815ms step_avg:974.89ms
step:775/1480 train_time:745807ms step_avg:974.91ms
step:775/1480 val_loss:2.7933 train_time:745848ms step_avg:974.96ms perplexity:6.9321 param_count:85,143,606
step:776/1480 train_time:746789ms step_avg:974.92ms
step:777/1480 train_time:747782ms step_avg:974.94ms
step:778/1480 train_time:748773ms step_avg:974.96ms
step:779/1480 train_time:749756ms step_avg:974.98ms
step:780/1480 train_time:750745ms step_avg:974.99ms
step:781/1480 train_time:751743ms step_avg:975.02ms
step:782/1480 train_time:752738ms step_avg:975.05ms
step:783/1480 train_time:753732ms step_avg:975.07ms
step:784/1480 train_time:754722ms step_avg:975.09ms
step:785/1480 train_time:755724ms step_avg:975.13ms
step:786/1480 train_time:756716ms step_avg:975.15ms
step:787/1480 train_time:757710ms step_avg:975.17ms
step:788/1480 train_time:758703ms step_avg:975.20ms
step:789/1480 train_time:759700ms step_avg:975.22ms
step:790/1480 train_time:760697ms step_avg:975.25ms
step:791/1480 train_time:761694ms step_avg:975.28ms
step:792/1480 train_time:762687ms step_avg:975.30ms
step:793/1480 train_time:763681ms step_avg:975.33ms
step:794/1480 train_time:764671ms step_avg:975.35ms
step:795/1480 train_time:765667ms step_avg:975.37ms
step:796/1480 train_time:766653ms step_avg:975.38ms
step:797/1480 train_time:767658ms step_avg:975.42ms
step:798/1480 train_time:768651ms step_avg:975.45ms
step:799/1480 train_time:769643ms step_avg:975.47ms
step:800/1480 train_time:770642ms step_avg:975.50ms
step:800/1480 val_loss:2.7936 train_time:770687ms step_avg:975.55ms perplexity:6.9335 param_count:85,143,606
step:801/1480 train_time:771631ms step_avg:975.51ms
step:802/1480 train_time:772626ms step_avg:975.54ms
step:803/1480 train_time:773614ms step_avg:975.55ms
step:804/1480 train_time:774616ms step_avg:975.59ms
step:805/1480 train_time:775608ms step_avg:975.61ms
step:806/1480 train_time:776606ms step_avg:975.64ms
step:807/1480 train_time:777597ms step_avg:975.65ms
step:808/1480 train_time:778595ms step_avg:975.68ms
step:809/1480 train_time:779591ms step_avg:975.71ms
step:810/1480 train_time:780576ms step_avg:975.72ms
step:811/1480 train_time:781573ms step_avg:975.75ms
step:812/1480 train_time:782568ms step_avg:975.77ms
step:813/1480 train_time:783561ms step_avg:975.79ms
step:814/1480 train_time:784563ms step_avg:975.82ms
step:815/1480 train_time:785564ms step_avg:975.86ms
step:816/1480 train_time:786562ms step_avg:975.88ms
step:817/1480 train_time:787553ms step_avg:975.90ms
step:818/1480 train_time:788540ms step_avg:975.92ms
step:819/1480 train_time:789533ms step_avg:975.94ms
step:820/1480 train_time:790534ms step_avg:975.97ms
step:821/1480 train_time:791525ms step_avg:975.99ms
step:822/1480 train_time:792528ms step_avg:976.02ms
step:823/1480 train_time:793519ms step_avg:976.04ms
step:824/1480 train_time:794509ms step_avg:976.06ms
step:825/1480 train_time:795499ms step_avg:976.07ms
step:825/1480 val_loss:2.7994 train_time:795545ms step_avg:976.13ms perplexity:6.9613 param_count:85,143,606
step:826/1480 train_time:796500ms step_avg:976.10ms
step:827/1480 train_time:797502ms step_avg:976.13ms
step:828/1480 train_time:798489ms step_avg:976.15ms
step:829/1480 train_time:799482ms step_avg:976.17ms
step:830/1480 train_time:800474ms step_avg:976.19ms
step:831/1480 train_time:801454ms step_avg:976.19ms
step:832/1480 train_time:802450ms step_avg:976.22ms
step:833/1480 train_time:803445ms step_avg:976.24ms
step:834/1480 train_time:804450ms step_avg:976.27ms
step:835/1480 train_time:805443ms step_avg:976.29ms
step:836/1480 train_time:806442ms step_avg:976.32ms
step:837/1480 train_time:807442ms step_avg:976.35ms
step:838/1480 train_time:808428ms step_avg:976.36ms
step:839/1480 train_time:809423ms step_avg:976.39ms
step:840/1480 train_time:810413ms step_avg:976.40ms
step:841/1480 train_time:811404ms step_avg:976.42ms
step:842/1480 train_time:812389ms step_avg:976.43ms
step:843/1480 train_time:813390ms step_avg:976.46ms
step:844/1480 train_time:814388ms step_avg:976.48ms
step:845/1480 train_time:815390ms step_avg:976.51ms
step:846/1480 train_time:816383ms step_avg:976.53ms
step:847/1480 train_time:817378ms step_avg:976.56ms
step:848/1480 train_time:818362ms step_avg:976.57ms
step:849/1480 train_time:819352ms step_avg:976.58ms
step:850/1480 train_time:820345ms step_avg:976.60ms
step:850/1480 val_loss:2.8074 train_time:820390ms step_avg:976.65ms perplexity:7.0000 param_count:85,143,606
step:851/1480 train_time:821336ms step_avg:976.62ms
step:852/1480 train_time:822329ms step_avg:976.64ms
step:853/1480 train_time:823313ms step_avg:976.65ms
step:854/1480 train_time:824309ms step_avg:976.67ms
step:855/1480 train_time:825315ms step_avg:976.70ms
step:856/1480 train_time:826315ms step_avg:976.73ms
step:857/1480 train_time:827308ms step_avg:976.75ms
step:858/1480 train_time:828301ms step_avg:976.77ms
step:859/1480 train_time:829293ms step_avg:976.79ms
step:860/1480 train_time:830284ms step_avg:976.80ms
step:861/1480 train_time:831284ms step_avg:976.83ms
step:862/1480 train_time:832296ms step_avg:976.87ms
step:863/1480 train_time:833300ms step_avg:976.91ms
step:864/1480 train_time:834296ms step_avg:976.93ms
step:865/1480 train_time:835295ms step_avg:976.95ms
step:866/1480 train_time:836287ms step_avg:976.97ms
step:867/1480 train_time:837269ms step_avg:976.98ms
step:868/1480 train_time:838258ms step_avg:976.99ms
step:869/1480 train_time:839253ms step_avg:977.01ms
step:870/1480 train_time:840240ms step_avg:977.02ms
step:871/1480 train_time:841241ms step_avg:977.05ms
step:872/1480 train_time:842234ms step_avg:977.07ms
step:873/1480 train_time:843233ms step_avg:977.10ms
step:874/1480 train_time:844218ms step_avg:977.10ms
step:875/1480 train_time:845206ms step_avg:977.12ms
step:875/1480 val_loss:2.7946 train_time:845250ms step_avg:977.17ms perplexity:6.9382 param_count:85,143,606
step:876/1480 train_time:846197ms step_avg:977.13ms
step:877/1480 train_time:847190ms step_avg:977.15ms
step:878/1480 train_time:848182ms step_avg:977.17ms
step:879/1480 train_time:849176ms step_avg:977.19ms
step:880/1480 train_time:850181ms step_avg:977.22ms
step:881/1480 train_time:851185ms step_avg:977.25ms
step:882/1480 train_time:852199ms step_avg:977.29ms
step:883/1480 train_time:853188ms step_avg:977.31ms
step:884/1480 train_time:854184ms step_avg:977.33ms
step:885/1480 train_time:855165ms step_avg:977.33ms
step:886/1480 train_time:856173ms step_avg:977.37ms
step:887/1480 train_time:857163ms step_avg:977.38ms
step:888/1480 train_time:858146ms step_avg:977.39ms
step:889/1480 train_time:859137ms step_avg:977.40ms
step:890/1480 train_time:860132ms step_avg:977.42ms
step:891/1480 train_time:861127ms step_avg:977.44ms
step:892/1480 train_time:862153ms step_avg:977.50ms
step:893/1480 train_time:863146ms step_avg:977.52ms
step:894/1480 train_time:864144ms step_avg:977.54ms
step:895/1480 train_time:865144ms step_avg:977.56ms
step:896/1480 train_time:866141ms step_avg:977.59ms
step:897/1480 train_time:867156ms step_avg:977.63ms
step:898/1480 train_time:868161ms step_avg:977.66ms
step:899/1480 train_time:869150ms step_avg:977.67ms
step:900/1480 train_time:870149ms step_avg:977.70ms
step:900/1480 val_loss:2.8037 train_time:870191ms step_avg:977.74ms perplexity:6.9824 param_count:85,143,606
step:901/1480 train_time:871134ms step_avg:977.70ms
step:902/1480 train_time:872132ms step_avg:977.73ms
step:903/1480 train_time:873121ms step_avg:977.74ms
step:904/1480 train_time:874132ms step_avg:977.78ms
step:905/1480 train_time:875129ms step_avg:977.80ms
step:906/1480 train_time:876115ms step_avg:977.81ms
step:907/1480 train_time:877112ms step_avg:977.83ms
step:908/1480 train_time:878130ms step_avg:977.87ms
step:909/1480 train_time:879126ms step_avg:977.89ms
step:910/1480 train_time:880122ms step_avg:977.91ms
step:911/1480 train_time:881123ms step_avg:977.94ms
step:912/1480 train_time:882118ms step_avg:977.96ms
step:913/1480 train_time:883108ms step_avg:977.97ms
step:914/1480 train_time:884098ms step_avg:977.98ms
step:915/1480 train_time:885099ms step_avg:978.01ms
step:916/1480 train_time:886089ms step_avg:978.02ms
step:917/1480 train_time:887088ms step_avg:978.05ms
step:918/1480 train_time:888086ms step_avg:978.07ms
step:919/1480 train_time:889081ms step_avg:978.09ms
step:920/1480 train_time:890068ms step_avg:978.10ms
step:921/1480 train_time:891068ms step_avg:978.12ms
step:922/1480 train_time:892069ms step_avg:978.15ms
step:923/1480 train_time:893049ms step_avg:978.15ms
step:924/1480 train_time:894050ms step_avg:978.17ms
step:925/1480 train_time:895050ms step_avg:978.20ms
step:925/1480 val_loss:2.8074 train_time:895095ms step_avg:978.25ms perplexity:7.0002 param_count:85,143,606
step:926/1480 train_time:896052ms step_avg:978.22ms
step:927/1480 train_time:897045ms step_avg:978.24ms
step:928/1480 train_time:898041ms step_avg:978.26ms
step:929/1480 train_time:899031ms step_avg:978.27ms
step:930/1480 train_time:900023ms step_avg:978.29ms
step:931/1480 train_time:901015ms step_avg:978.30ms
step:932/1480 train_time:902015ms step_avg:978.32ms
step:933/1480 train_time:903009ms step_avg:978.34ms
step:934/1480 train_time:904026ms step_avg:978.38ms
step:935/1480 train_time:905023ms step_avg:978.40ms
step:936/1480 train_time:906021ms step_avg:978.42ms
step:937/1480 train_time:907018ms step_avg:978.44ms
step:938/1480 train_time:908007ms step_avg:978.46ms
step:939/1480 train_time:909015ms step_avg:978.49ms
step:940/1480 train_time:910001ms step_avg:978.50ms
step:941/1480 train_time:911008ms step_avg:978.53ms
step:942/1480 train_time:912011ms step_avg:978.55ms
step:943/1480 train_time:913000ms step_avg:978.56ms
step:944/1480 train_time:913988ms step_avg:978.57ms
step:945/1480 train_time:914979ms step_avg:978.59ms
step:946/1480 train_time:915971ms step_avg:978.60ms
step:947/1480 train_time:916967ms step_avg:978.62ms
step:948/1480 train_time:917956ms step_avg:978.63ms
step:949/1480 train_time:918963ms step_avg:978.66ms
step:950/1480 train_time:919952ms step_avg:978.67ms
step:950/1480 val_loss:2.7909 train_time:919998ms step_avg:978.72ms perplexity:6.9206 param_count:85,143,606
step:951/1480 train_time:920941ms step_avg:978.68ms
step:952/1480 train_time:921936ms step_avg:978.70ms
step:953/1480 train_time:922940ms step_avg:978.73ms
step:954/1480 train_time:923928ms step_avg:978.74ms
step:955/1480 train_time:924928ms step_avg:978.76ms
step:956/1480 train_time:925924ms step_avg:978.78ms
step:957/1480 train_time:926919ms step_avg:978.80ms
step:958/1480 train_time:927911ms step_avg:978.81ms
step:959/1480 train_time:928895ms step_avg:978.81ms
step:960/1480 train_time:929890ms step_avg:978.83ms
step:961/1480 train_time:930899ms step_avg:978.86ms
step:962/1480 train_time:931894ms step_avg:978.88ms
step:963/1480 train_time:932897ms step_avg:978.91ms
step:964/1480 train_time:933891ms step_avg:978.92ms
step:965/1480 train_time:934891ms step_avg:978.94ms
step:966/1480 train_time:935886ms step_avg:978.96ms
step:967/1480 train_time:936888ms step_avg:978.98ms
step:968/1480 train_time:937884ms step_avg:979.00ms
step:969/1480 train_time:938883ms step_avg:979.02ms
step:970/1480 train_time:939879ms step_avg:979.04ms
step:971/1480 train_time:940874ms step_avg:979.06ms
step:972/1480 train_time:941866ms step_avg:979.07ms
step:973/1480 train_time:942858ms step_avg:979.08ms
step:974/1480 train_time:943852ms step_avg:979.10ms
step:975/1480 train_time:944851ms step_avg:979.12ms
step:975/1480 val_loss:2.7823 train_time:944896ms step_avg:979.17ms perplexity:6.8797 param_count:85,143,606
step:976/1480 train_time:945842ms step_avg:979.13ms
step:977/1480 train_time:946836ms step_avg:979.15ms
step:978/1480 train_time:947848ms step_avg:979.18ms
step:979/1480 train_time:948852ms step_avg:979.21ms
step:980/1480 train_time:949848ms step_avg:979.23ms
step:981/1480 train_time:950832ms step_avg:979.23ms
step:982/1480 train_time:951827ms step_avg:979.25ms
step:983/1480 train_time:952824ms step_avg:979.26ms
step:984/1480 train_time:953822ms step_avg:979.28ms
step:985/1480 train_time:954817ms step_avg:979.30ms
step:986/1480 train_time:955809ms step_avg:979.31ms
step:987/1480 train_time:956805ms step_avg:979.33ms
step:988/1480 train_time:957801ms step_avg:979.35ms
step:989/1480 train_time:958795ms step_avg:979.36ms
step:990/1480 train_time:959795ms step_avg:979.38ms
step:991/1480 train_time:960787ms step_avg:979.40ms
step:992/1480 train_time:961784ms step_avg:979.41ms
step:993/1480 train_time:962795ms step_avg:979.45ms
step:994/1480 train_time:963813ms step_avg:979.48ms
step:995/1480 train_time:964810ms step_avg:979.50ms
step:996/1480 train_time:965832ms step_avg:979.55ms
step:997/1480 train_time:966829ms step_avg:979.56ms
step:998/1480 train_time:967836ms step_avg:979.59ms
step:999/1480 train_time:968835ms step_avg:979.61ms
step:1000/1480 train_time:969834ms step_avg:979.63ms
step:1000/1480 val_loss:2.7799 train_time:969879ms step_avg:979.68ms perplexity:6.8682 param_count:85,143,606
step:1001/1480 train_time:970815ms step_avg:979.63ms
step:1002/1480 train_time:971810ms step_avg:979.65ms
step:1003/1480 train_time:972807ms step_avg:979.66ms
step:1004/1480 train_time:973819ms step_avg:979.70ms
step:1005/1480 train_time:974817ms step_avg:979.72ms
step:1006/1480 train_time:975821ms step_avg:979.74ms
step:1007/1480 train_time:976840ms step_avg:979.78ms
step:1008/1480 train_time:977832ms step_avg:979.79ms
step:1009/1480 train_time:978835ms step_avg:979.81ms
step:1010/1480 train_time:979826ms step_avg:979.83ms
step:1011/1480 train_time:980822ms step_avg:979.84ms
step:1012/1480 train_time:981826ms step_avg:979.87ms
step:1013/1480 train_time:982847ms step_avg:979.91ms
step:1014/1480 train_time:983846ms step_avg:979.93ms
step:1015/1480 train_time:984829ms step_avg:979.93ms
step:1016/1480 train_time:985831ms step_avg:979.95ms
step:1017/1480 train_time:986857ms step_avg:980.00ms
step:1018/1480 train_time:987852ms step_avg:980.01ms
step:1019/1480 train_time:988859ms step_avg:980.04ms
step:1020/1480 train_time:989874ms step_avg:980.07ms
step:1021/1480 train_time:990867ms step_avg:980.09ms
step:1022/1480 train_time:991862ms step_avg:980.10ms
step:1023/1480 train_time:992852ms step_avg:980.11ms
step:1024/1480 train_time:993860ms step_avg:980.14ms
step:1025/1480 train_time:994860ms step_avg:980.16ms
step:1025/1480 val_loss:2.7853 train_time:994905ms step_avg:980.20ms perplexity:6.8939 param_count:85,143,606
step:1026/1480 train_time:995855ms step_avg:980.17ms
step:1027/1480 train_time:996847ms step_avg:980.18ms
step:1028/1480 train_time:997844ms step_avg:980.20ms
step:1029/1480 train_time:998840ms step_avg:980.22ms
step:1030/1480 train_time:999838ms step_avg:980.23ms
step:1031/1480 train_time:1000842ms step_avg:980.26ms
step:1032/1480 train_time:1001838ms step_avg:980.27ms
step:1033/1480 train_time:1002823ms step_avg:980.28ms
step:1034/1480 train_time:1003815ms step_avg:980.29ms
step:1035/1480 train_time:1004809ms step_avg:980.30ms
step:1036/1480 train_time:1005794ms step_avg:980.31ms
step:1037/1480 train_time:1006792ms step_avg:980.32ms
step:1038/1480 train_time:1007789ms step_avg:980.34ms
step:1039/1480 train_time:1008795ms step_avg:980.36ms
step:1040/1480 train_time:1009780ms step_avg:980.37ms
step:1041/1480 train_time:1010779ms step_avg:980.39ms
step:1042/1480 train_time:1011775ms step_avg:980.40ms
step:1043/1480 train_time:1012797ms step_avg:980.44ms
step:1044/1480 train_time:1013792ms step_avg:980.46ms
step:1045/1480 train_time:1014797ms step_avg:980.48ms
step:1046/1480 train_time:1015814ms step_avg:980.52ms
step:1047/1480 train_time:1016802ms step_avg:980.52ms
step:1048/1480 train_time:1017811ms step_avg:980.55ms
step:1049/1480 train_time:1018804ms step_avg:980.56ms
step:1050/1480 train_time:1019796ms step_avg:980.57ms
step:1050/1480 val_loss:2.7843 train_time:1019838ms step_avg:980.61ms perplexity:6.8889 param_count:85,143,606
step:1051/1480 train_time:1020774ms step_avg:980.57ms
step:1052/1480 train_time:1021767ms step_avg:980.58ms
step:1053/1480 train_time:1022770ms step_avg:980.60ms
step:1054/1480 train_time:1023783ms step_avg:980.63ms
step:1055/1480 train_time:1024780ms step_avg:980.65ms
step:1056/1480 train_time:1025774ms step_avg:980.66ms
step:1057/1480 train_time:1026772ms step_avg:980.68ms
step:1058/1480 train_time:1027785ms step_avg:980.71ms
step:1059/1480 train_time:1028775ms step_avg:980.72ms
step:1060/1480 train_time:1029765ms step_avg:980.73ms
step:1061/1480 train_time:1030761ms step_avg:980.74ms
step:1062/1480 train_time:1031764ms step_avg:980.76ms
step:1063/1480 train_time:1032768ms step_avg:980.79ms
step:1064/1480 train_time:1033760ms step_avg:980.80ms
step:1065/1480 train_time:1034755ms step_avg:980.81ms
step:1066/1480 train_time:1035738ms step_avg:980.81ms
step:1067/1480 train_time:1036746ms step_avg:980.84ms
step:1068/1480 train_time:1037747ms step_avg:980.86ms
step:1069/1480 train_time:1038748ms step_avg:980.88ms
step:1070/1480 train_time:1039755ms step_avg:980.90ms
step:1071/1480 train_time:1040769ms step_avg:980.93ms
step:1072/1480 train_time:1041761ms step_avg:980.94ms
step:1073/1480 train_time:1042762ms step_avg:980.96ms
step:1074/1480 train_time:1043757ms step_avg:980.98ms
step:1075/1480 train_time:1044762ms step_avg:981.00ms
step:1075/1480 val_loss:2.7784 train_time:1044802ms step_avg:981.03ms perplexity:6.8607 param_count:85,143,606
step:1076/1480 train_time:1045741ms step_avg:981.00ms
step:1077/1480 train_time:1046731ms step_avg:981.00ms
step:1078/1480 train_time:1047728ms step_avg:981.02ms
step:1079/1480 train_time:1048745ms step_avg:981.05ms
step:1080/1480 train_time:1049751ms step_avg:981.08ms
step:1081/1480 train_time:1050752ms step_avg:981.09ms
step:1082/1480 train_time:1051765ms step_avg:981.12ms
step:1083/1480 train_time:1052756ms step_avg:981.13ms
step:1084/1480 train_time:1053748ms step_avg:981.14ms
step:1085/1480 train_time:1054747ms step_avg:981.16ms
step:1086/1480 train_time:1055746ms step_avg:981.18ms
step:1087/1480 train_time:1056749ms step_avg:981.20ms
step:1088/1480 train_time:1057741ms step_avg:981.21ms
step:1089/1480 train_time:1058748ms step_avg:981.23ms
step:1090/1480 train_time:1059741ms step_avg:981.24ms
step:1091/1480 train_time:1060744ms step_avg:981.26ms
step:1092/1480 train_time:1061744ms step_avg:981.28ms
step:1093/1480 train_time:1062735ms step_avg:981.29ms
step:1094/1480 train_time:1063724ms step_avg:981.30ms
step:1095/1480 train_time:1064728ms step_avg:981.32ms
step:1096/1480 train_time:1065725ms step_avg:981.33ms
step:1097/1480 train_time:1066714ms step_avg:981.34ms
step:1098/1480 train_time:1067707ms step_avg:981.35ms
step:1099/1480 train_time:1068706ms step_avg:981.36ms
step:1100/1480 train_time:1069695ms step_avg:981.37ms
step:1100/1480 val_loss:2.7896 train_time:1069740ms step_avg:981.41ms perplexity:6.9145 param_count:85,143,606
step:1101/1480 train_time:1070685ms step_avg:981.38ms
step:1102/1480 train_time:1071689ms step_avg:981.40ms
step:1103/1480 train_time:1072696ms step_avg:981.42ms
step:1104/1480 train_time:1073697ms step_avg:981.44ms
step:1105/1480 train_time:1074691ms step_avg:981.45ms
step:1106/1480 train_time:1075688ms step_avg:981.47ms
step:1107/1480 train_time:1076702ms step_avg:981.50ms
step:1108/1480 train_time:1077705ms step_avg:981.52ms
step:1109/1480 train_time:1078694ms step_avg:981.52ms
step:1110/1480 train_time:1079689ms step_avg:981.54ms
step:1111/1480 train_time:1080691ms step_avg:981.55ms
step:1112/1480 train_time:1081681ms step_avg:981.56ms
step:1113/1480 train_time:1082686ms step_avg:981.58ms
step:1114/1480 train_time:1083686ms step_avg:981.60ms
step:1115/1480 train_time:1084676ms step_avg:981.61ms
step:1116/1480 train_time:1085669ms step_avg:981.62ms
step:1117/1480 train_time:1086670ms step_avg:981.63ms
step:1118/1480 train_time:1087681ms step_avg:981.66ms
step:1119/1480 train_time:1088685ms step_avg:981.68ms
step:1120/1480 train_time:1089678ms step_avg:981.69ms
step:1121/1480 train_time:1090676ms step_avg:981.71ms
step:1122/1480 train_time:1091661ms step_avg:981.71ms
step:1123/1480 train_time:1092654ms step_avg:981.72ms
step:1124/1480 train_time:1093644ms step_avg:981.73ms
step:1125/1480 train_time:1094634ms step_avg:981.73ms
step:1125/1480 val_loss:2.7833 train_time:1094678ms step_avg:981.77ms perplexity:6.8841 param_count:85,143,606
step:1126/1480 train_time:1095615ms step_avg:981.73ms
step:1127/1480 train_time:1096614ms step_avg:981.75ms
step:1128/1480 train_time:1097616ms step_avg:981.77ms
step:1129/1480 train_time:1098615ms step_avg:981.78ms
step:1130/1480 train_time:1099613ms step_avg:981.80ms
step:1131/1480 train_time:1100622ms step_avg:981.82ms
step:1132/1480 train_time:1101625ms step_avg:981.84ms
step:1133/1480 train_time:1102619ms step_avg:981.85ms
step:1134/1480 train_time:1103620ms step_avg:981.87ms
step:1135/1480 train_time:1104621ms step_avg:981.89ms
step:1136/1480 train_time:1105619ms step_avg:981.90ms
step:1137/1480 train_time:1106614ms step_avg:981.91ms
step:1138/1480 train_time:1107614ms step_avg:981.93ms
step:1139/1480 train_time:1108609ms step_avg:981.94ms
step:1140/1480 train_time:1109599ms step_avg:981.95ms
step:1141/1480 train_time:1110592ms step_avg:981.96ms
step:1142/1480 train_time:1111577ms step_avg:981.96ms
step:1143/1480 train_time:1112579ms step_avg:981.98ms
step:1144/1480 train_time:1113585ms step_avg:982.00ms
step:1145/1480 train_time:1114588ms step_avg:982.02ms
step:1146/1480 train_time:1115588ms step_avg:982.03ms
step:1147/1480 train_time:1116593ms step_avg:982.05ms
step:1148/1480 train_time:1117589ms step_avg:982.06ms
step:1149/1480 train_time:1118573ms step_avg:982.07ms
step:1150/1480 train_time:1119573ms step_avg:982.08ms
step:1150/1480 val_loss:2.7754 train_time:1119618ms step_avg:982.12ms perplexity:6.8468 param_count:85,143,606
step:1151/1480 train_time:1120581ms step_avg:982.10ms
step:1152/1480 train_time:1121585ms step_avg:982.12ms
step:1153/1480 train_time:1122575ms step_avg:982.13ms
step:1154/1480 train_time:1123573ms step_avg:982.14ms
step:1155/1480 train_time:1124575ms step_avg:982.16ms
step:1156/1480 train_time:1125571ms step_avg:982.17ms
step:1157/1480 train_time:1126567ms step_avg:982.19ms
step:1158/1480 train_time:1127562ms step_avg:982.20ms
step:1159/1480 train_time:1128556ms step_avg:982.21ms
step:1160/1480 train_time:1129559ms step_avg:982.23ms
step:1161/1480 train_time:1130552ms step_avg:982.23ms
step:1162/1480 train_time:1131544ms step_avg:982.24ms
step:1163/1480 train_time:1132536ms step_avg:982.25ms
step:1164/1480 train_time:1133531ms step_avg:982.26ms
step:1165/1480 train_time:1134522ms step_avg:982.27ms
step:1166/1480 train_time:1135527ms step_avg:982.29ms
step:1167/1480 train_time:1136532ms step_avg:982.31ms
step:1168/1480 train_time:1137540ms step_avg:982.33ms
step:1169/1480 train_time:1138536ms step_avg:982.34ms
step:1170/1480 train_time:1139544ms step_avg:982.37ms
step:1171/1480 train_time:1140544ms step_avg:982.38ms
step:1172/1480 train_time:1141543ms step_avg:982.40ms
step:1173/1480 train_time:1142594ms step_avg:982.45ms
step:1174/1480 train_time:1143593ms step_avg:982.47ms
step:1175/1480 train_time:1144594ms step_avg:982.48ms
step:1175/1480 val_loss:2.7808 train_time:1144637ms step_avg:982.52ms perplexity:6.8724 param_count:85,143,606
step:1176/1480 train_time:1145574ms step_avg:982.48ms
step:1177/1480 train_time:1146574ms step_avg:982.50ms
step:1178/1480 train_time:1147571ms step_avg:982.51ms
step:1179/1480 train_time:1148571ms step_avg:982.52ms
step:1180/1480 train_time:1149561ms step_avg:982.53ms
step:1181/1480 train_time:1150563ms step_avg:982.55ms
step:1182/1480 train_time:1151571ms step_avg:982.57ms
step:1183/1480 train_time:1152578ms step_avg:982.59ms
step:1184/1480 train_time:1153576ms step_avg:982.60ms
step:1185/1480 train_time:1154569ms step_avg:982.61ms
step:1186/1480 train_time:1155561ms step_avg:982.62ms
step:1187/1480 train_time:1156554ms step_avg:982.63ms
step:1188/1480 train_time:1157551ms step_avg:982.64ms
step:1189/1480 train_time:1158537ms step_avg:982.64ms
step:1190/1480 train_time:1159540ms step_avg:982.66ms
step:1191/1480 train_time:1160565ms step_avg:982.70ms
step:1192/1480 train_time:1161558ms step_avg:982.71ms
step:1193/1480 train_time:1162559ms step_avg:982.72ms
step:1194/1480 train_time:1163562ms step_avg:982.74ms
step:1195/1480 train_time:1164554ms step_avg:982.75ms
step:1196/1480 train_time:1165558ms step_avg:982.76ms
step:1197/1480 train_time:1166564ms step_avg:982.78ms
step:1198/1480 train_time:1167557ms step_avg:982.79ms
step:1199/1480 train_time:1168557ms step_avg:982.81ms
step:1200/1480 train_time:1169557ms step_avg:982.82ms
step:1200/1480 val_loss:2.7660 train_time:1169597ms step_avg:982.85ms perplexity:6.8021 param_count:85,143,606
step:1201/1480 train_time:1170548ms step_avg:982.83ms
step:1202/1480 train_time:1171550ms step_avg:982.84ms
step:1203/1480 train_time:1172538ms step_avg:982.85ms
step:1204/1480 train_time:1173536ms step_avg:982.86ms
step:1205/1480 train_time:1174549ms step_avg:982.89ms
step:1206/1480 train_time:1175539ms step_avg:982.89ms
step:1207/1480 train_time:1176534ms step_avg:982.90ms
step:1208/1480 train_time:1177524ms step_avg:982.91ms
step:1209/1480 train_time:1178521ms step_avg:982.92ms
step:1210/1480 train_time:1179531ms step_avg:982.94ms
step:1211/1480 train_time:1180531ms step_avg:982.96ms
step:1212/1480 train_time:1181534ms step_avg:982.97ms
step:1213/1480 train_time:1182535ms step_avg:982.99ms
step:1214/1480 train_time:1183531ms step_avg:983.00ms
step:1215/1480 train_time:1184537ms step_avg:983.02ms
step:1216/1480 train_time:1185530ms step_avg:983.03ms
step:1217/1480 train_time:1186525ms step_avg:983.04ms
step:1218/1480 train_time:1187529ms step_avg:983.05ms
step:1219/1480 train_time:1188532ms step_avg:983.07ms
step:1220/1480 train_time:1189525ms step_avg:983.08ms
step:1221/1480 train_time:1190539ms step_avg:983.10ms
step:1222/1480 train_time:1191534ms step_avg:983.11ms
step:1223/1480 train_time:1192527ms step_avg:983.12ms
step:1224/1480 train_time:1193522ms step_avg:983.13ms
step:1225/1480 train_time:1194534ms step_avg:983.16ms
step:1225/1480 val_loss:2.7724 train_time:1194579ms step_avg:983.19ms perplexity:6.8325 param_count:85,143,606
step:1226/1480 train_time:1195526ms step_avg:983.16ms
step:1227/1480 train_time:1196527ms step_avg:983.18ms
step:1228/1480 train_time:1197537ms step_avg:983.20ms
step:1229/1480 train_time:1198539ms step_avg:983.21ms
step:1230/1480 train_time:1199537ms step_avg:983.23ms
step:1231/1480 train_time:1200555ms step_avg:983.26ms
step:1232/1480 train_time:1201559ms step_avg:983.27ms
step:1233/1480 train_time:1202560ms step_avg:983.29ms
step:1234/1480 train_time:1203559ms step_avg:983.30ms
step:1235/1480 train_time:1204556ms step_avg:983.31ms
step:1236/1480 train_time:1205547ms step_avg:983.32ms
step:1237/1480 train_time:1206551ms step_avg:983.33ms
step:1238/1480 train_time:1207552ms step_avg:983.35ms
step:1239/1480 train_time:1208546ms step_avg:983.36ms
step:1240/1480 train_time:1209561ms step_avg:983.38ms
step:1241/1480 train_time:1210564ms step_avg:983.40ms
step:1242/1480 train_time:1211575ms step_avg:983.42ms
step:1243/1480 train_time:1212570ms step_avg:983.43ms
step:1244/1480 train_time:1213569ms step_avg:983.44ms
step:1245/1480 train_time:1214580ms step_avg:983.47ms
step:1246/1480 train_time:1215586ms step_avg:983.48ms
step:1247/1480 train_time:1216604ms step_avg:983.51ms
step:1248/1480 train_time:1217599ms step_avg:983.52ms
step:1249/1480 train_time:1218584ms step_avg:983.52ms
step:1250/1480 train_time:1219590ms step_avg:983.54ms
step:1250/1480 val_loss:2.7646 train_time:1219634ms step_avg:983.58ms perplexity:6.7956 param_count:85,143,606
step:1251/1480 train_time:1220583ms step_avg:983.55ms
step:1252/1480 train_time:1221571ms step_avg:983.55ms
step:1253/1480 train_time:1222583ms step_avg:983.57ms
step:1254/1480 train_time:1223579ms step_avg:983.58ms
step:1255/1480 train_time:1224582ms step_avg:983.60ms
step:1256/1480 train_time:1225581ms step_avg:983.61ms
step:1257/1480 train_time:1226579ms step_avg:983.62ms
step:1258/1480 train_time:1227582ms step_avg:983.64ms
step:1259/1480 train_time:1228564ms step_avg:983.64ms
step:1260/1480 train_time:1229575ms step_avg:983.66ms
step:1261/1480 train_time:1230574ms step_avg:983.67ms
step:1262/1480 train_time:1231573ms step_avg:983.68ms
step:1263/1480 train_time:1232572ms step_avg:983.70ms
step:1264/1480 train_time:1233563ms step_avg:983.70ms
step:1265/1480 train_time:1234610ms step_avg:983.75ms
step:1266/1480 train_time:1235581ms step_avg:983.74ms
step:1267/1480 train_time:1236592ms step_avg:983.76ms
step:1268/1480 train_time:1237625ms step_avg:983.80ms
step:1269/1480 train_time:1238630ms step_avg:983.82ms
step:1270/1480 train_time:1239642ms step_avg:983.84ms
step:1271/1480 train_time:1240638ms step_avg:983.85ms
step:1272/1480 train_time:1241631ms step_avg:983.86ms
step:1273/1480 train_time:1242632ms step_avg:983.87ms
step:1274/1480 train_time:1243635ms step_avg:983.89ms
step:1275/1480 train_time:1244634ms step_avg:983.90ms
step:1275/1480 val_loss:2.7668 train_time:1244675ms step_avg:983.93ms perplexity:6.8059 param_count:85,143,606
step:1276/1480 train_time:1245623ms step_avg:983.90ms
step:1277/1480 train_time:1246607ms step_avg:983.90ms
step:1278/1480 train_time:1247606ms step_avg:983.92ms
step:1279/1480 train_time:1248609ms step_avg:983.93ms
step:1280/1480 train_time:1249601ms step_avg:983.94ms
step:1281/1480 train_time:1250603ms step_avg:983.95ms
step:1282/1480 train_time:1251605ms step_avg:983.97ms
step:1283/1480 train_time:1252596ms step_avg:983.97ms
step:1284/1480 train_time:1253600ms step_avg:983.99ms
step:1285/1480 train_time:1254603ms step_avg:984.00ms
step:1286/1480 train_time:1255601ms step_avg:984.01ms
step:1287/1480 train_time:1256594ms step_avg:984.02ms
step:1288/1480 train_time:1257597ms step_avg:984.04ms
step:1289/1480 train_time:1258598ms step_avg:984.05ms
step:1290/1480 train_time:1259598ms step_avg:984.06ms
step:1291/1480 train_time:1260608ms step_avg:984.08ms
step:1292/1480 train_time:1261613ms step_avg:984.10ms
step:1293/1480 train_time:1262623ms step_avg:984.12ms
step:1294/1480 train_time:1263616ms step_avg:984.12ms
step:1295/1480 train_time:1264628ms step_avg:984.15ms
step:1296/1480 train_time:1265622ms step_avg:984.15ms
step:1297/1480 train_time:1266647ms step_avg:984.19ms
step:1298/1480 train_time:1267641ms step_avg:984.19ms
step:1299/1480 train_time:1268637ms step_avg:984.20ms
step:1300/1480 train_time:1269635ms step_avg:984.21ms
step:1300/1480 val_loss:2.7717 train_time:1269679ms step_avg:984.25ms perplexity:6.8292 param_count:85,143,606
step:1301/1480 train_time:1270638ms step_avg:984.23ms
step:1302/1480 train_time:1271650ms step_avg:984.25ms
step:1303/1480 train_time:1272644ms step_avg:984.26ms
step:1304/1480 train_time:1273649ms step_avg:984.27ms
step:1305/1480 train_time:1274655ms step_avg:984.29ms
step:1306/1480 train_time:1275657ms step_avg:984.30ms
step:1307/1480 train_time:1276648ms step_avg:984.31ms
step:1308/1480 train_time:1277633ms step_avg:984.31ms
step:1309/1480 train_time:1278630ms step_avg:984.32ms
step:1310/1480 train_time:1279641ms step_avg:984.34ms
step:1311/1480 train_time:1280635ms step_avg:984.35ms
step:1312/1480 train_time:1281637ms step_avg:984.36ms
step:1313/1480 train_time:1282648ms step_avg:984.38ms
step:1314/1480 train_time:1283648ms step_avg:984.39ms
step:1315/1480 train_time:1284648ms step_avg:984.40ms
step:1316/1480 train_time:1285647ms step_avg:984.42ms
step:1317/1480 train_time:1286663ms step_avg:984.44ms
step:1318/1480 train_time:1287663ms step_avg:984.45ms
step:1319/1480 train_time:1288674ms step_avg:984.47ms
step:1320/1480 train_time:1289683ms step_avg:984.49ms
step:1321/1480 train_time:1290679ms step_avg:984.50ms
step:1322/1480 train_time:1291675ms step_avg:984.51ms
step:1323/1480 train_time:1292673ms step_avg:984.52ms
step:1324/1480 train_time:1293658ms step_avg:984.52ms
step:1325/1480 train_time:1294654ms step_avg:984.53ms
step:1325/1480 val_loss:2.7655 train_time:1294699ms step_avg:984.56ms perplexity:6.7996 param_count:85,143,606
step:1326/1480 train_time:1295646ms step_avg:984.53ms
step:1327/1480 train_time:1296662ms step_avg:984.56ms
step:1328/1480 train_time:1297660ms step_avg:984.57ms
step:1329/1480 train_time:1298653ms step_avg:984.57ms
step:1330/1480 train_time:1299655ms step_avg:984.59ms
step:1331/1480 train_time:1300656ms step_avg:984.60ms
step:1332/1480 train_time:1301656ms step_avg:984.61ms
step:1333/1480 train_time:1302650ms step_avg:984.62ms
step:1334/1480 train_time:1303668ms step_avg:984.64ms
step:1335/1480 train_time:1304659ms step_avg:984.65ms
step:1336/1480 train_time:1305684ms step_avg:984.68ms
step:1337/1480 train_time:1306694ms step_avg:984.70ms
step:1338/1480 train_time:1307715ms step_avg:984.72ms
step:1339/1480 train_time:1308733ms step_avg:984.75ms
step:1340/1480 train_time:1309753ms step_avg:984.78ms
step:1341/1480 train_time:1310755ms step_avg:984.79ms
step:1342/1480 train_time:1311755ms step_avg:984.80ms
step:1343/1480 train_time:1312746ms step_avg:984.81ms
step:1344/1480 train_time:1313750ms step_avg:984.82ms
step:1345/1480 train_time:1314742ms step_avg:984.83ms
step:1346/1480 train_time:1315749ms step_avg:984.84ms
step:1347/1480 train_time:1316743ms step_avg:984.85ms
step:1348/1480 train_time:1317746ms step_avg:984.86ms
step:1349/1480 train_time:1318752ms step_avg:984.88ms
step:1350/1480 train_time:1319761ms step_avg:984.90ms
step:1350/1480 val_loss:2.7609 train_time:1319802ms step_avg:984.93ms perplexity:6.7782 param_count:85,143,606
step:1351/1480 train_time:1320785ms step_avg:984.93ms
step:1352/1480 train_time:1321796ms step_avg:984.95ms
step:1353/1480 train_time:1322798ms step_avg:984.96ms
step:1354/1480 train_time:1323795ms step_avg:984.97ms
step:1355/1480 train_time:1324793ms step_avg:984.98ms
step:1356/1480 train_time:1325796ms step_avg:984.99ms
step:1357/1480 train_time:1326791ms step_avg:985.00ms
step:1358/1480 train_time:1327809ms step_avg:985.02ms
step:1359/1480 train_time:1328811ms step_avg:985.03ms
step:1360/1480 train_time:1329808ms step_avg:985.04ms
step:1361/1480 train_time:1330803ms step_avg:985.05ms
step:1362/1480 train_time:1331819ms step_avg:985.07ms
step:1363/1480 train_time:1332813ms step_avg:985.08ms
step:1364/1480 train_time:1333840ms step_avg:985.11ms
step:1365/1480 train_time:1334838ms step_avg:985.12ms
step:1366/1480 train_time:1335835ms step_avg:985.13ms
step:1367/1480 train_time:1336828ms step_avg:985.14ms
step:1368/1480 train_time:1337830ms step_avg:985.15ms
step:1369/1480 train_time:1338820ms step_avg:985.15ms
step:1370/1480 train_time:1339823ms step_avg:985.16ms
step:1371/1480 train_time:1340828ms step_avg:985.18ms
step:1372/1480 train_time:1341840ms step_avg:985.20ms
step:1373/1480 train_time:1342839ms step_avg:985.21ms
step:1374/1480 train_time:1343840ms step_avg:985.22ms
step:1375/1480 train_time:1344840ms step_avg:985.23ms
step:1375/1480 val_loss:2.7646 train_time:1344885ms step_avg:985.26ms perplexity:6.7954 param_count:85,143,606
step:1376/1480 train_time:1345824ms step_avg:985.23ms
step:1377/1480 train_time:1346831ms step_avg:985.25ms
step:1378/1480 train_time:1347859ms step_avg:985.28ms
step:1379/1480 train_time:1348856ms step_avg:985.29ms
step:1380/1480 train_time:1349858ms step_avg:985.30ms
step:1381/1480 train_time:1350851ms step_avg:985.30ms
step:1382/1480 train_time:1351857ms step_avg:985.32ms
step:1383/1480 train_time:1352854ms step_avg:985.33ms
step:1384/1480 train_time:1353899ms step_avg:985.37ms
step:1385/1480 train_time:1354898ms step_avg:985.38ms
step:1386/1480 train_time:1355892ms step_avg:985.39ms
step:1387/1480 train_time:1356907ms step_avg:985.41ms
step:1388/1480 train_time:1357904ms step_avg:985.42ms
step:1389/1480 train_time:1358902ms step_avg:985.43ms
step:1390/1480 train_time:1359897ms step_avg:985.43ms
step:1391/1480 train_time:1360905ms step_avg:985.45ms
step:1392/1480 train_time:1361886ms step_avg:985.45ms
step:1393/1480 train_time:1362878ms step_avg:985.45ms
step:1394/1480 train_time:1363869ms step_avg:985.45ms
step:1395/1480 train_time:1364863ms step_avg:985.46ms
step:1396/1480 train_time:1365871ms step_avg:985.48ms
step:1397/1480 train_time:1366881ms step_avg:985.49ms
step:1398/1480 train_time:1367884ms step_avg:985.51ms
step:1399/1480 train_time:1368886ms step_avg:985.52ms
step:1400/1480 train_time:1369872ms step_avg:985.52ms
step:1400/1480 val_loss:2.7606 train_time:1369917ms step_avg:985.55ms perplexity:6.7768 param_count:85,143,606
step:1401/1480 train_time:1370861ms step_avg:985.52ms
step:1402/1480 train_time:1371868ms step_avg:985.54ms
step:1403/1480 train_time:1372863ms step_avg:985.54ms
step:1404/1480 train_time:1373858ms step_avg:985.55ms
step:1405/1480 train_time:1374872ms step_avg:985.57ms
step:1406/1480 train_time:1375866ms step_avg:985.58ms
step:1407/1480 train_time:1376875ms step_avg:985.59ms
step:1408/1480 train_time:1377885ms step_avg:985.61ms
step:1409/1480 train_time:1378882ms step_avg:985.62ms
step:1410/1480 train_time:1379877ms step_avg:985.63ms
step:1411/1480 train_time:1380880ms step_avg:985.64ms
step:1412/1480 train_time:1381877ms step_avg:985.65ms
step:1413/1480 train_time:1382871ms step_avg:985.65ms
step:1414/1480 train_time:1383871ms step_avg:985.66ms
step:1415/1480 train_time:1384879ms step_avg:985.68ms
step:1416/1480 train_time:1385863ms step_avg:985.68ms
step:1417/1480 train_time:1386847ms step_avg:985.68ms
step:1418/1480 train_time:1387853ms step_avg:985.69ms
step:1419/1480 train_time:1388851ms step_avg:985.70ms
step:1420/1480 train_time:1389862ms step_avg:985.72ms
step:1421/1480 train_time:1390858ms step_avg:985.72ms
step:1422/1480 train_time:1391861ms step_avg:985.74ms
step:1423/1480 train_time:1392856ms step_avg:985.74ms
step:1424/1480 train_time:1393851ms step_avg:985.75ms
step:1425/1480 train_time:1394844ms step_avg:985.76ms
step:1425/1480 val_loss:2.7502 train_time:1394889ms step_avg:985.79ms perplexity:6.7280 param_count:85,143,606
step:1426/1480 train_time:1395843ms step_avg:985.77ms
step:1427/1480 train_time:1396843ms step_avg:985.77ms
step:1428/1480 train_time:1397861ms step_avg:985.80ms
step:1429/1480 train_time:1398865ms step_avg:985.81ms
step:1430/1480 train_time:1399846ms step_avg:985.81ms
step:1431/1480 train_time:1400842ms step_avg:985.81ms
step:1432/1480 train_time:1401842ms step_avg:985.82ms
step:1433/1480 train_time:1402839ms step_avg:985.83ms
step:1434/1480 train_time:1403834ms step_avg:985.84ms
step:1435/1480 train_time:1404826ms step_avg:985.84ms
step:1436/1480 train_time:1405832ms step_avg:985.86ms
step:1437/1480 train_time:1406822ms step_avg:985.86ms
step:1438/1480 train_time:1407857ms step_avg:985.89ms
step:1439/1480 train_time:1408843ms step_avg:985.89ms
step:1440/1480 train_time:1409833ms step_avg:985.90ms
step:1441/1480 train_time:1410843ms step_avg:985.91ms
step:1442/1480 train_time:1411847ms step_avg:985.93ms
step:1443/1480 train_time:1412835ms step_avg:985.93ms
step:1444/1480 train_time:1413833ms step_avg:985.94ms
step:1445/1480 train_time:1414827ms step_avg:985.94ms
step:1446/1480 train_time:1415839ms step_avg:985.96ms
step:1447/1480 train_time:1416843ms step_avg:985.97ms
step:1448/1480 train_time:1417856ms step_avg:985.99ms
step:1449/1480 train_time:1418870ms step_avg:986.01ms
step:1450/1480 train_time:1419880ms step_avg:986.03ms
step:1450/1480 val_loss:2.7465 train_time:1419926ms step_avg:986.06ms perplexity:6.7110 param_count:85,143,606
step:1451/1480 train_time:1420877ms step_avg:986.04ms
step:1452/1480 train_time:1421919ms step_avg:986.07ms
step:1453/1480 train_time:1422915ms step_avg:986.08ms
step:1454/1480 train_time:1423907ms step_avg:986.09ms
step:1455/1480 train_time:1424898ms step_avg:986.09ms
step:1456/1480 train_time:1425899ms step_avg:986.10ms
step:1457/1480 train_time:1426909ms step_avg:986.12ms
step:1458/1480 train_time:1427905ms step_avg:986.12ms
step:1459/1480 train_time:1428902ms step_avg:986.13ms
step:1460/1480 train_time:1429899ms step_avg:986.14ms
step:1461/1480 train_time:1430893ms step_avg:986.14ms
step:1462/1480 train_time:1431879ms step_avg:986.14ms
step:1463/1480 train_time:1432887ms step_avg:986.16ms
step:1464/1480 train_time:1433906ms step_avg:986.18ms
step:1465/1480 train_time:1434887ms step_avg:986.18ms
step:1466/1480 train_time:1435889ms step_avg:986.19ms
step:1467/1480 train_time:1436896ms step_avg:986.20ms
step:1468/1480 train_time:1437900ms step_avg:986.21ms
step:1469/1480 train_time:1438900ms step_avg:986.22ms
step:1470/1480 train_time:1439910ms step_avg:986.24ms
step:1471/1480 train_time:1440902ms step_avg:986.24ms
step:1472/1480 train_time:1441893ms step_avg:986.25ms
step:1473/1480 train_time:1442895ms step_avg:986.26ms
step:1474/1480 train_time:1443890ms step_avg:986.26ms
step:1475/1480 train_time:1444891ms step_avg:986.27ms
step:1475/1480 val_loss:2.7499 train_time:1444936ms step_avg:986.30ms perplexity:6.7267 param_count:85,143,606
step:1476/1480 train_time:1445887ms step_avg:986.28ms
step:1477/1480 train_time:1446888ms step_avg:986.29ms
step:1478/1480 train_time:1447907ms step_avg:986.31ms
step:1479/1480 train_time:1448906ms step_avg:986.32ms
step:1480/1480 train_time:1449901ms step_avg:986.33ms
step:1480/1480 val_loss:2.7509 train_time:1449941ms step_avg:986.35ms perplexity:6.7316 param_count:85,143,606
peak memory consumption: 12942 MiB
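
Note on the validation lines above: for every printed pair, the perplexity matches 2**val_loss rather than exp(val_loss) (e.g. 2**2.7509 ≈ 6.73 for the final step). Below is a minimal sketch of how this log could be parsed offline and that relationship checked; the log file name and the regular expression are assumptions for illustration, not part of the training script.

import math
import re

# Hypothetical path to a saved copy of the log above; substitute the real one.
LOG_PATH = "esm2_train.log"

# Matches lines like:
# step:1480/1480 val_loss:2.7509 train_time:1449941ms step_avg:986.35ms perplexity:6.7316 param_count:85,143,606
val_line = re.compile(r"step:(\d+)/\d+ val_loss:([\d.]+) .*perplexity:([\d.]+)")

with open(LOG_PATH) as fh:
    for line in fh:
        m = val_line.search(line)
        if m is None:
            continue  # skip per-step train_time lines
        step = int(m.group(1))
        loss = float(m.group(2))
        ppl = float(m.group(3))
        # The printed perplexity appears to be 2**val_loss (base-2), not math.exp(val_loss).
        assert abs(2 ** loss - ppl) < 1e-2, (step, loss, ppl)
        print(f"step {step}: val_loss={loss:.4f} perplexity={ppl:.4f} (2**loss={2 ** loss:.4f})")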