# Chunked fused linear + cross-entropy loss implemented with torch.compile.
# inspired by https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899
# discussion: https://github.com/linkedin/Liger-Kernel/issues/227
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


def cdiv(x: int, y: int):
    """Ceiling division."""
    return (x + y - 1) // y


def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n


class ChunkedCE(torch.autograd.Function):
    """Fused linear + cross-entropy that processes the flattened batch in chunks.

    Each chunk's logits are materialized, used to compute the chunk loss and
    gradients (via torch.func.grad_and_value), then discarded, so the full
    (BT, V) logits tensor never exists in memory. Gradients are accumulated
    during forward() and simply returned from backward().
    """

    @staticmethod
    def forward(ctx, _input, weight, target, bias=None, softcap_value=None, chunk_size=None):
        BT, H = _input.shape
        if chunk_size is None:
            chunk_size = 1024
        if chunk_size == 'auto':
            V = weight.shape[0]
            inc_factor = cdiv(V, H)  # (V + H - 1) // H
            chunk_size = next_power_of_2(cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
        chunks = cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

        total_n_non_ignore = (target != -100).sum()

        def compute_loss(input_chunk, weight, bias, target):
            # if bias is not None:
            #     logits = torch.addmm(bias, input_chunk, weight.t())
            # else:
            #     logits = torch.matmul(input_chunk, weight.t())
            # more memory efficient when bias is set
            logits = F.linear(input_chunk, weight, bias)
            if softcap_value is not None:
                logits = torch.tanh(logits / softcap_value) * softcap_value
            logits = logits.float()
            loss = F.cross_entropy(logits, target)
            return loss

        grad_weight = torch.zeros_like(weight)
        grad_input = torch.zeros_like(_input)
        grad_bias = torch.zeros_like(bias) if bias is not None else None
        loss_acc = torch.zeros((), device=_input.device)

        @torch.compile(dynamic=True, options={"shape_padding": True})
        def accumulate_chunk(input_chunk, target_chunk):
            if bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), chunk_loss = torch.func.grad_and_value(
                    compute_loss, argnums=(0, 1, 2)
                )(input_chunk, weight, bias, target_chunk)
            else:
                (chunk_grad_input, chunk_grad_weight), chunk_loss = torch.func.grad_and_value(
                    compute_loss, argnums=(0, 1)
                )(input_chunk, weight, None, target_chunk)
                chunk_grad_bias = None
            # compute_loss returns the mean over the chunk's non-ignored targets;
            # scale by that count so chunks can be summed and normalized globally.
            n_non_ignore = (target_chunk != -100).sum().item()
            grad_weight.add_(chunk_grad_weight * n_non_ignore)
            if grad_bias is not None:
                grad_bias.add_(chunk_grad_bias * n_non_ignore)
            loss_acc.add_(chunk_loss * n_non_ignore)
            return chunk_grad_input * n_non_ignore

        accu_len = 0
        for chunk_id in range(chunks):
            start_idx = chunk_id * chunk_size
            end_idx = min((chunk_id + 1) * chunk_size, BT)
            input_chunk = _input[start_idx:end_idx]
            target_chunk = target[start_idx:end_idx]
            grad_input[accu_len:accu_len + input_chunk.shape[0]] = accumulate_chunk(input_chunk, target_chunk)
            accu_len += input_chunk.shape[0]
        ctx.save_for_backward(
            grad_input / total_n_non_ignore,
            grad_weight / total_n_non_ignore,
            grad_bias / total_n_non_ignore if grad_bias is not None else None,
        )
        return loss_acc / total_n_non_ignore

    @staticmethod
    def backward(ctx, grad_output):
        # gradients were already computed in forward(); assumes grad_output == 1.0
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        return (grad_input, grad_weight, None, grad_bias, None, None)


class CompiledFusedLinearCrossEntropyLoss(CrossEntropyLoss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, lin_weight, _input, target, bias=None, softcap_value=None, chunk_size=None):
        return ChunkedCE.apply(
            _input, lin_weight, target, bias, softcap_value, chunk_size
        )


if __name__ == "__main__":
    torch.set_default_device('cuda')
    from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

    # B, T, D, V = 4, 2048, 4096, 128256  # alternative problem size (overridden below)
    B, T, D, V = 4, 4096, 3584, 256000

    model = nn.Linear(D, V, bias=True).to(torch.bfloat16)
    x = torch.randn(B, T, D, requires_grad=True, dtype=torch.bfloat16)
    label = torch.randint(0, V, (B, T)).to(torch.int64)

    def f(m, x, label):
        out = F.cross_entropy(m(x).view(-1, V), label.view(-1))
        out.backward()
        return out

    def chunked_f(m, x, label, chunk_size=None):
        out = ChunkedCE.apply(x.view(-1, D), m.weight, label.view(-1), m.bias, None, chunk_size)
        out.backward()
        return out

    def ligerf(m, x, label):
        out = LigerFusedLinearCrossEntropyFunction.apply(x.view(-1, D), m.weight, label.view(-1), m.bias)
        out.backward()
        return out

    def bench(f, name=None, iters=100, warmup=5, display=True, profile=False, profile_mem=False):
        from triton.testing import do_bench
        for _ in range(warmup):
            f()
        if profile_mem:
            torch.cuda.memory._record_memory_history()
            f()
            torch.cuda.memory._dump_snapshot(f"{name if name is not None else 'memory'}.pickle")
        if profile:
            with torch.profiler.profile() as prof:
                f()
            prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

        torch.cuda.reset_peak_memory_stats()
        ms_per_iter = do_bench(lambda: f())
        if name is None:
            res = ms_per_iter
        else:
            res = f"{name}: {ms_per_iter:.3f}ms"
        if display:
            print(res)
            print("Peak mem: ", torch.cuda.max_memory_allocated() / 1e9)
            print()
        return res

    opt_f = torch.compile(f)
    # bench(lambda: f(model, x, label), name='eager lce (non-chunked)')
    # bench(lambda: opt_f(model, x, label), name='compile lce (non-chunked)')
    bench(lambda: ligerf(model, x, label), name='liger lce')
    bench(lambda: chunked_f(model, x, label, chunk_size=1024), name='compile lce (chunk 1024)')
    # bench(lambda: chunked_f(model, x, label, chunk_size='auto'), name='compile lce (chunk auto)')
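
    # A minimal usage sketch of the module-style wrapper above (not part of the
    # original benchmark): CompiledFusedLinearCrossEntropyLoss takes the linear
    # layer's weight, the flattened hidden states, and the flattened labels, and
    # dispatches to ChunkedCE.apply. Uncomment to try it with the tensors defined
    # in this script.
    # loss_fn = CompiledFusedLinearCrossEntropyLoss()
    # loss = loss_fn(model.weight, x.view(-1, D), label.view(-1), model.bias, None, 1024)
    # loss.backward()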