# inspired by https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899
# discussion: https://github.com/linkedin/Liger-Kernel/issues/227
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

def cdiv(x: int, y: int):
    return (x + y - 1) // y

def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n

class ChunkedCE(torch.autograd.Function):
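    """Chunked fused linear + cross-entropy loss.

    The full (BT, V) logits tensor is never materialized: the batch*time dimension
    is processed in chunks, and gradients w.r.t. the input, weight and bias are
    computed during the forward pass (via torch.func.grad_and_value) and cached,
    so backward only needs to return them.
    """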
    @staticmethod
    def forward(ctx, _input, weight, target, bias=None, softcap_value=None, chunk_size=None):
        BT, H = _input.shape
        
        if chunk_size is None:
            chunk_size = 1024
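        # 'auto' picks a chunk size such that a chunk's logits (chunk_size x V) take
        # roughly as much memory as the full hidden states (BT x H), up to power-of-2 rounding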
        if chunk_size == 'auto':
            V = weight.shape[0]
            inc_factor = cdiv(V, H)  # (V + H - 1) // H
            chunk_size = next_power_of_2(cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
        chunks = cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
        
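        # number of non-ignored labels (ignore_index=-100) in the whole batch, used to
        # turn per-chunk mean losses/grads into one global mean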
        total_n_non_ignore = (target != -100).sum()
        def compute_loss(input_chunk, weight, bias, target):
            # torch.addmm(bias, input_chunk, weight.t()) / torch.matmul(input_chunk, weight.t())
            # would also work, but F.linear is more memory efficient when a bias is set
            logits = F.linear(input_chunk, weight, bias)

            if softcap_value is not None:
                logits = torch.tanh(logits / softcap_value) * softcap_value
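            # upcast to fp32 before cross-entropy so the log-softmax reduction stays
            # numerically stable with bf16/fp16 activations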
            logits = logits.float()
            loss = F.cross_entropy(logits, target)
            return loss

        grad_weight = torch.zeros_like(weight)
        grad_input = torch.zeros_like(_input)
        grad_bias = torch.zeros_like(bias) if bias is not None else None
        loss_acc = torch.zeros((), device=_input.device)
        
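        # accumulate_chunk gets the chunk loss and its grads w.r.t. input, weight and
        # bias in a single pass via torch.func.grad_and_value; torch.compile keeps the
        # per-chunk matmul + (optional softcap) + cross-entropy reasonably fused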
        @torch.compile(dynamic=True, options={"shape_padding": True})
        def accumulate_chunk(input_chunk, target_chunk):
            if bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), chunk_loss = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))(input_chunk, weight, bias, target_chunk)
            else:
                (chunk_grad_input, chunk_grad_weight), chunk_loss = torch.func.grad_and_value(compute_loss, argnums=(0, 1))(input_chunk, weight, None, target_chunk)
                chunk_grad_bias = None
            
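            # the chunk loss/grads are means over this chunk's non-ignored tokens;
            # scale by n_non_ignore to convert them to sums so that dividing by
            # total_n_non_ignore at the end gives the global mean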
            n_non_ignore = (target_chunk != -100).sum().item()
            grad_weight.add_(chunk_grad_weight * n_non_ignore)
            if grad_bias is not None: grad_bias.add_(chunk_grad_bias * n_non_ignore)
            loss_acc.add_(chunk_loss * n_non_ignore)
            return chunk_grad_input * n_non_ignore
        
        for chunk_id in range(chunks):
            start_idx = chunk_id * chunk_size
            end_idx = min((chunk_id + 1) * chunk_size, BT)
            input_chunk = _input[start_idx: end_idx]
            target_chunk = target[start_idx: end_idx]
            grad_input[start_idx: end_idx] = accumulate_chunk(input_chunk, target_chunk)
            
        
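        # gradients are already fully accumulated at this point; normalize by the
        # global non-ignored count and cache them so backward can simply return them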
        ctx.save_for_backward(
            grad_input / total_n_non_ignore,
            grad_weight / total_n_non_ignore,
            grad_bias / total_n_non_ignore if grad_bias is not None else None,
        )
        return loss_acc / total_n_non_ignore

    @staticmethod
    def backward(ctx, grad_output):
        # grads were precomputed in forward assuming a scalar loss (grad_output == 1);
        # scale by grad_output in case the loss is scaled or combined downstream
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        grad_bias = grad_bias * grad_output if grad_bias is not None else None
        return (grad_input * grad_output, grad_weight * grad_output, None, grad_bias, None, None)
    

class CompiledFusedLinearCrossEntropyLoss(CrossEntropyLoss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, lin_weight, _input, target, bias=None, softcap_value=None, chunk_size=None):
        return ChunkedCE.apply(
            _input, lin_weight, target, bias, softcap_value, chunk_size
        )
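
# Example usage of the module wrapper (a minimal sketch; shapes/dtypes are illustrative):
#   loss_fn = CompiledFusedLinearCrossEntropyLoss()
#   lm_head = nn.Linear(4096, 128256, bias=False).to(torch.bfloat16)          # H=4096, V=128256
#   hidden = torch.randn(8, 4096, dtype=torch.bfloat16, requires_grad=True)   # (B*T, H)
#   labels = torch.randint(0, 128256, (8,))
#   loss = loss_fn(lm_head.weight, hidden, labels)  # note: the lm_head weight comes first
#   loss.backward()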


if __name__ == "__main__":
    torch.set_default_device('cuda')
    from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
    
    # B, T, D, V = 4, 2048, 4096, 128256  # alternative smaller config, kept for reference
    B, T, D, V = 4, 4096, 3584, 256000
    model = nn.Linear(D, V, bias=True).to(torch.bfloat16)
    x = torch.randn(B, T, D, requires_grad=True, dtype=torch.bfloat16)
    label = torch.randint(0, V, (B, T)).to(torch.int64)
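    # for these shapes, a single full (B*T, V) bf16 logits tensor alone is roughly
    # B*T*V*2 bytes ~= 8.4 GB, which is the allocation the chunked version avoids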
    def f(m, x, label):
        out = F.cross_entropy(m(x).view(-1, V), label.view(-1))
        out.backward()
        return out

    def chunked_f(m, x, label, chunk_size=None):
        out = ChunkedCE.apply(x.view(-1, D), m.weight, label.view(-1), m.bias, None, chunk_size)
        out.backward()
        return out

    def ligerf(m, x, label):
        out = LigerFusedLinearCrossEntropyFunction.apply(x.view(-1, D), m.weight, label.view(-1), m.bias)
        out.backward()
        return out

    def bench(f, name=None, iters=100, warmup=5, display=True, profile=False, profile_mem=False):
        from triton.testing import do_bench
        for _ in range(warmup):
            f()

        if profile_mem:
            torch.cuda.memory._record_memory_history()
            f()
            torch.cuda.memory._dump_snapshot(f"{name if name is not None else 'memory'}.pickle")
        if profile:
            with torch.profiler.profile() as prof:
                f()
            prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

        torch.cuda.reset_peak_memory_stats()
        ms_per_iter = do_bench(lambda: f())
        if name is None:
            res = ms_per_iter
        else:
            res = f"{name}: {ms_per_iter:.3f}ms"
        if display:
            print(res)
            print("Peak mem: ", torch.cuda.max_memory_allocated()/1e9)
            print()
        return res

    opt_f = torch.compile(f)
    # bench(lambda: f(model, x, label), name='eager lce (non-chunked)')
    # bench(lambda: opt_f(model, x, label), name='compile lce (non-chunked)')
    bench(lambda: ligerf(model, x, label), name='liger lce')
    bench(lambda: chunked_f(model, x, label, chunk_size=1024), name='compile lce (chunk 1024)')
    # bench(lambda: chunked_f(model, x, label, chunk_size='auto'), name='compile lce (chunk auto)')
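
    # Optional correctness spot-check (a sketch, not part of the original benchmark):
    # compare the chunked loss against the eager reference on fresh gradients.
    # model.zero_grad(); x.grad = None
    # ref = f(model, x, label)
    # model.zero_grad(); x.grad = None
    # chk = chunked_f(model, x, label, chunk_size=1024)
    # print("loss abs diff:", (ref - chk).abs().item())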