"""
A Minimal Implementation of IMM (Inductive Moment Matching)
"""
import math

import torch
import torch.nn.functional as F


def compute_mmd_loss_fully_vectorized(
## Note
## If vLLM runs on the same GPU, set a low gpu_memory_utilization to avoid OOM
## For larger models, consider using multi-GPU or CPU offloading
## AnySchedule: https://github.com/KohakuBlueleaf/AnySchedule
## LyCORIS: https://github.com/KohakuBlueleaf/LyCORIS
## The following code can perform reasonable training of the Llama-3.2-1B-Instruct model on the GSM8K dataset,
## with noticeable improvement on each reward function
from itertools import chain
import re