CompressAI PyTorch Lightning example
Data module: CLICDataModule, which wraps compressai.datasets.ImageFolder for the train/valid/test splits.
import pytorch_lightning as pl
from compressai.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms


class CLICDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, patch_size, **dataloader_kwargs):
        super().__init__()
        self.data_dir = data_dir
        self.train_transform = transforms.Compose(
            [
                transforms.RandomCrop(patch_size),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
            ]
        )
        self.val_transform = transforms.Compose(
            [
                transforms.CenterCrop(patch_size),
                transforms.ToTensor(),
            ]
        )
        self.test_transform = self.val_transform
        self.dataloader_kwargs = dataloader_kwargs

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        self.train_dataset = ImageFolder(
            self.data_dir, split="train", transform=self.train_transform
        )
        self.val_dataset = ImageFolder(
            self.data_dir, split="valid", transform=self.val_transform
        )
        self.test_dataset = ImageFolder(
            self.data_dir, split="test", transform=self.test_transform
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, shuffle=True, **self.dataloader_kwargs
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, shuffle=False, **self.dataloader_kwargs
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, shuffle=False, **self.dataloader_kwargs
        )
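For context, a minimal usage sketch of the data module. The dataset path, patch size, and loader settings below are illustrative assumptions; compressai.datasets.ImageFolder reads images from the split subdirectories under the given root, so data_dir is expected to contain train/, valid/, and test/ folders.

# Illustrative sketch; path and loader settings are placeholder assumptions.
data_module = CLICDataModule(
    data_dir="/path/to/clic",  # hypothetical root containing train/, valid/, test/
    patch_size=(256, 256),
    batch_size=16,
    num_workers=4,
    pin_memory=True,
)
data_module.setup()
train_loader = data_module.train_dataloader()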
Lightning module: LitSFUDenoiseScalable, with manual optimization of the main and auxiliary optimizers.
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.optim as optim
from omegaconf import OmegaConf

from sfu_compression.losses import RateDistortionLoss
from sfu_compression.models import SFUDenoiseScalable
from sfu_compression.utils import (
    create_noise_model,
    git_branch_name,
    git_common_ancestor_hash,
    git_current_hash,
)


class LitSFUDenoiseScalable(pl.LightningModule):
    def __init__(
        self,
        conf: Optional[OmegaConf] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(conf)
        self.save_hyperparameters(kwargs)
        self.model = SFUDenoiseScalable(
            N=self.hparams.architecture.num_channels,
            BASE_N=self.hparams.architecture.num_base_channels,
        )
        self.criterion = RateDistortionLoss(
            lmbda=self.hparams.training.lmbda,
            w1d=self.hparams.training.w1d,
            w2d=self.hparams.training.w2d,
            w3d=self.hparams.training.w3d,
            w1r=self.hparams.training.w1r,
            w2r=self.hparams.training.w2r,
            w3r=self.hparams.training.w3r,
        )
        self.noise_model = create_noise_model(self.hparams.noise_model)
        # The main and auxiliary optimizers are stepped manually in training_step.
        self.automatic_optimization = False

    def forward(self, x):
        # TODO compress, decompress?
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        optimizer, aux_optimizer = self.optimizers()
        optimizer.zero_grad()
        aux_optimizer.zero_grad()
        out_net = self.model(x_noise)
        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        loss = out_criterion["loss"]
        self.manual_backward(loss)
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.hparams.training.clip_max_norm
        )
        optimizer.step()
        # The auxiliary loss is optimized separately by aux_optimizer.
        aux_loss = self.model.aux_loss()
        self.manual_backward(aux_loss)
        aux_optimizer.step()
        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"train/{k}": v for k, v in log_dict.items()}
        self.log_dict(log_dict)

    def validation_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        out_net = self.model(x_noise)
        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        aux_loss = self.model.aux_loss()
        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"val/{k}": v for k, v in log_dict.items()}
        log_dict["val_loss"] = out_criterion["loss"]
        self.log_dict(log_dict)

    def validation_epoch_end(self, outputs):
        sch = self.lr_schedulers()
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            sch.step(self.trainer.callback_metrics["val/loss"])
        else:
            raise Exception("Expected a ReduceLROnPlateau scheduler.")

    def test_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        enc_dict = self.model.compress(x_noise)
        encoded = [s[0] for s in enc_dict["strings"]]
        result = self.model.decompress(**enc_dict)
        x_hat = result["x_hat"].numpy()[0]
        # TODO log metrics, etc; on_epoch, on_step

    def configure_optimizers(self):
        # Resolves to the module-level configure_optimizers helper defined below;
        # the method name does not shadow it inside the method body.
        optimizer, aux_optimizer = configure_optimizers(
            self.model, self.hparams.training
        )
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return (
            {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "monitor": "val/loss",
                },
            },
            {
                "optimizer": aux_optimizer,
            },
        )

    def on_fit_start(self):
        params = {
            "git": {
                "branch_name": git_branch_name(),
                "hash": git_current_hash(),
                "master_hash": git_common_ancestor_hash(),
            },
            **self.hparams,
        }
        metrics = {"hp/metric": -1}
        self.logger.log_hyperparams(params, metrics)

    def on_load_checkpoint(self, checkpoint):
        # Re-add the "model." prefix stripped by on_save_checkpoint.
        prefix = "model."
        checkpoint["state_dict"] = {
            f"{prefix}{k}": v for k, v in checkpoint["state_dict"].items()
        }

    def on_save_checkpoint(self, checkpoint):
        # Strip the "model." prefix so the saved state_dict matches the bare
        # SFUDenoiseScalable module.
        prefix_len = len("model.")
        checkpoint["state_dict"] = {
            k[prefix_len:]: v for k, v in checkpoint["state_dict"].items()
        }


def configure_optimizers(net, args):
    """Separate parameters for the main optimizer and the auxiliary optimizer.

    Return two optimizers.
    """
    parameters = {
        n
        for n, p in net.named_parameters()
        if not n.endswith(".quantiles") and p.requires_grad
    }
    aux_parameters = {
        n
        for n, p in net.named_parameters()
        if n.endswith(".quantiles") and p.requires_grad
    }

    # Make sure we don't have an intersection of parameters
    params_dict = dict(net.named_parameters())
    inter_params = parameters & aux_parameters
    union_params = parameters | aux_parameters
    assert len(inter_params) == 0
    assert len(union_params) - len(params_dict.keys()) == 0

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(parameters)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_parameters)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer
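Because on_save_checkpoint strips the "model." prefix, the checkpoint's state_dict lines up with the bare SFUDenoiseScalable module (assuming the criterion and noise model contribute no parameters of their own, which is what the prefix handling presumes). A rough sketch of reusing such a checkpoint outside Lightning; the path, channel counts, and the update() call are assumptions based on the CompressAI CompressionModel interface.

import torch

from sfu_compression.models import SFUDenoiseScalable

# Sketch only: the checkpoint path and channel counts are placeholders.
checkpoint = torch.load("path/to/checkpoints/last.ckpt", map_location="cpu")
net = SFUDenoiseScalable(N=192, BASE_N=128)  # assumed to match the trained config
net.load_state_dict(checkpoint["state_dict"])
net.update(force=True)  # assumes a CompressAI-style model; refreshes entropy coder CDFs
net.eval()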
Training script: builds the data module, model, callbacks, and logger, then runs trainer.fit.
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torchinfo import summary

from sfu_compression.datasets import CLICDataModule
from sfu_compression.models import LitSFUDenoiseScalable
from sfu_compression.utils import parse_args_training


def load_model_from_args(args: OmegaConf):
    continue_from = args.training_params.continue_from
    if continue_from != "":
        checkpoint_path = f"{continue_from}/checkpoints/last.ckpt"
        return LitSFUDenoiseScalable.load_from_checkpoint(checkpoint_path)
    conf = OmegaConf.merge(args.hparams, {"noise_model": args.noise_model})
    return LitSFUDenoiseScalable(conf=conf)


def main():
    args = parse_args_training()

    if args.training_params.seed is not None:
        pl.seed_everything(args.training_params.seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    data_module = CLICDataModule(
        data_dir=args.other.dataset,
        patch_size=args.training_params.patch_size,
        batch_size=args.training_params.batch_size,
        num_workers=args.training_params.num_workers,
        pin_memory=True,
    )

    model = load_model_from_args(args)

    # Show network with layer sizes.
    h, w = args.training_params.patch_size
    empty_img = (1, 3, h, w)
    summary(model.model, [empty_img])

    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        filename="{epoch:04d}-{val_loss:.2f}",
        save_last=True,
        save_top_k=1,
        mode="min",
    )
    early_stopping_callback = EarlyStopping("val/loss", patience=15)
    lr_monitor_callback = LearningRateMonitor(logging_interval="epoch")

    tb_logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name="",
        default_hp_metric=False,
    )

    trainer_kwargs = dict(
        callbacks=[
            checkpoint_callback,
            early_stopping_callback,
            lr_monitor_callback,
        ],
        logger=tb_logger,
    )
    trainer_kwargs = {**args.pytorch_lightning_trainer, **trainer_kwargs}

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    main()
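parse_args_training is not shown, but the shape of the config it must return can be inferred from the fields accessed above. A hypothetical OmegaConf sketch follows; every value is an illustrative placeholder, and the noise_model contents in particular depend on what create_noise_model expects.

from omegaconf import OmegaConf

# Hypothetical config mirroring the fields read by the scripts above.
# All values are placeholders.
args = OmegaConf.create(
    {
        "hparams": {
            "architecture": {"num_channels": 192, "num_base_channels": 128},
            "training": {
                "lmbda": 0.01,
                "w1d": 1.0, "w2d": 1.0, "w3d": 1.0,
                "w1r": 1.0, "w2r": 1.0, "w3r": 1.0,
                "clip_max_norm": 1.0,
                "learning_rate": 1e-4,
                "aux_learning_rate": 1e-3,
            },
        },
        "noise_model": {"name": "gaussian", "sigma": 25},  # format assumed
        "training_params": {
            "continue_from": "",
            "seed": 42,
            "patch_size": [256, 256],
            "batch_size": 16,
            "num_workers": 4,
        },
        "other": {"dataset": "/path/to/clic"},
        # Trainer flags depend on the PyTorch Lightning version in use.
        "pytorch_lightning_trainer": {"max_epochs": 1000, "gpus": 1},
    }
)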