# Use imgaug with detectron2
# GitHub gist ak64th/3569b592ba9855a99e4402c5c719de8b, created September 29, 2021 04:03.
# GitHub notice: this file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open the
# file in an editor that reveals hidden Unicode characters.
import argparse | |
import copy | |
import os | |
import pathlib | |
from datetime import datetime | |
import imgaug.augmenters as iaa | |
import imgaug.random as ia_random | |
import numpy as np | |
import torch | |
from detectron2.config import get_cfg | |
from detectron2.data import build_detection_train_loader, build_detection_test_loader | |
from detectron2.data.detection_utils import ( | |
read_image, | |
check_image_size, | |
annotations_to_instances, | |
filter_empty_instances, | |
) | |
from detectron2.engine import DefaultTrainer | |
from detectron2.engine import launch | |
from detectron2.evaluation import COCOEvaluator | |
from detectron2.structures import BoxMode | |
from detectron2.utils.logger import setup_logger | |
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage | |
from imgaug.augmentables.polys import Polygon, PolygonsOnImage | |
class Mapper:
    """Detectron2 dataset mapper that applies an imgaug augmentation pipeline.

    Instances are callables: detectron2's data loaders call them with one
    dataset dict and get back a dict whose ``'image'`` is a CHW
    ``torch.Tensor`` and, when training, whose ``'instances'`` holds the
    (augmented) ground truth.
    """

    def __init__(
            self,
            augmenter: iaa.Augmenter,
            is_train: bool = True,
            image_format='BGR',
    ):
        """
        Args:
            augmenter: imgaug augmenter applied to every image. It is frozen
                with ``to_deterministic()`` on each call so that the image and
                its boxes/polygons receive the same random transformation.
            is_train: when False, annotations are dropped and no
                ``'instances'`` entry is produced.
            image_format: format passed to detectron2's ``read_image``.
        """
        self.augmenter = augmenter
        self.is_train = is_train
        self.image_format = image_format

    def __call__(self, dataset_dict):
        """Read, augment and convert a single dataset dict.

        The input dict is deep-copied first, so the caller's copy is never
        mutated.
        """
        dataset_dict = copy.deepcopy(dataset_dict)
        dataset_dict.pop('sem_seg_file_name', None)  # semantic segmentation is not used
        image = read_image(dataset_dict["file_name"], format=self.image_format)
        check_image_size(dataset_dict, image)
        # FIXME: resize the image?

        # Freeze the random parameters so the image and all of its annotations
        # are transformed identically.
        deterministic = self.augmenter.to_deterministic()
        augmented_image = deterministic.augment_image(image)
        # HWC numpy array -> contiguous CHW tensor (torch.Tensor for efficiency).
        dataset_dict['image'] = torch.as_tensor(
            np.ascontiguousarray(augmented_image.transpose(2, 0, 1))
        )

        if not self.is_train or 'annotations' not in dataset_dict:
            dataset_dict.pop('annotations', None)
            return dataset_dict

        annos = []
        for obj in dataset_dict.pop('annotations'):
            if obj.get('iscrowd', 0) != 0:
                continue  # crowd annotations are skipped
            obj.pop('keypoints', None)  # keypoints are not augmented
            annos.append(_transform_annotation(obj, image.shape, deterministic))

        # annotations_to_instances expects image_size as (height, width), not
        # the full (height, width, channels) array shape.
        instances = annotations_to_instances(
            annos, augmented_image.shape[:2], mask_format='polygon'
        )
        dataset_dict["instances"] = filter_empty_instances(instances)
        return dataset_dict
def _transform_annotation(annotation, image_shape, augmentation: iaa.Augmenter):
    """Apply a deterministic imgaug augmenter to one detectron2 annotation.

    The annotation dict is modified in place and also returned. Its ``bbox``
    is rewritten in ``BoxMode.XYXY_ABS``. If a ``segmentation`` is present,
    its polygons are transformed with the same augmenter.

    Args:
        annotation: annotation dict with ``bbox`` and ``bbox_mode`` keys and
            an optional polygon ``segmentation``.
        image_shape: shape of the image *before* augmentation.
        augmentation: a deterministic augmenter (from ``to_deterministic()``),
            so that every call applies the same transform as to the image.

    Returns:
        The same ``annotation`` dict.

    Raises:
        NotImplementedError: if ``segmentation`` is not a list of polygons.
    """
    assert augmentation.deterministic, 'Augmenter instance not deterministic.'

    # Transform the bounding box.
    bbox = BoxMode.convert(annotation['bbox'], annotation['bbox_mode'], BoxMode.XYXY_ABS)
    boxes = augmentation.augment_bounding_boxes(
        BoundingBoxesOnImage([BoundingBox(*bbox)], shape=image_shape)
    ).remove_out_of_image().clip_out_of_image().bounding_boxes
    if boxes:
        box = boxes[0]
        annotation['bbox'] = [box.x1, box.y1, box.x2, box.y2]
    else:
        # The box was moved completely outside the image. A zero-size box makes
        # detectron2's filter_empty_instances drop the instance, instead of
        # failing here with an IndexError.
        annotation['bbox'] = [0.0, 0.0, 0.0, 0.0]
    annotation['bbox_mode'] = BoxMode.XYXY_ABS

    if 'segmentation' not in annotation:
        return annotation

    # Transform the segmentation; for simplicity only polygons are handled.
    segm = annotation['segmentation']
    if not isinstance(segm, list):
        raise NotImplementedError('Unsupported segmentation format')
    polygons = [Polygon(np.asarray(p).reshape(-1, 2)) for p in segm]
    augmented_polygons = augmentation.augment_polygons(
        PolygonsOnImage(polygons, image_shape)
    ).remove_out_of_image().polygons
    annotation['segmentation'] = [p.coords.reshape(-1) for p in augmented_polygons]
    return annotation
class AugTrainer(DefaultTrainer):
    """DefaultTrainer whose data loaders augment images with imgaug.

    Training images get random horizontal flips and colour jitter. Test
    images go through the same Mapper with an identity augmenter and without
    annotations. Evaluation uses COCOEvaluator.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """Return a COCOEvaluator that writes to ``<OUTPUT_DIR>/inference``."""
        output_folder = os.path.join(cfg.OUTPUT_DIR, 'inference')
        return COCOEvaluator(dataset_name, output_dir=output_folder)

    @classmethod
    def _build_augmenter(cls, image_format='BGR'):
        """Return the imgaug pipeline used for training images.

        Args:
            image_format: channel order of the images fed to the pipeline.
                imgaug's colour augmenters assume RGB unless told otherwise,
                so BGR input must be declared explicitly.
        """
        colorspace = 'BGR' if image_format == 'BGR' else 'RGB'
        return iaa.Sequential([
            iaa.Fliplr(0.2),
            iaa.Sometimes(
                0.5,
                iaa.Sequential([
                    # Either hue/saturation jitter or partial desaturation.
                    iaa.Sometimes(
                        0.5,
                        iaa.AddToHueAndSaturation(
                            (-20, 20),
                            per_channel=True,
                            from_colorspace=colorspace,
                        ),
                        iaa.Grayscale(alpha=(0, .2), from_colorspace=colorspace),
                    ),
                    # Either a brightness shift or a faint edge-detection blend.
                    iaa.Sometimes(
                        0.5,
                        iaa.WithBrightnessChannels(
                            iaa.Add((-50, 50)),
                            from_colorspace=colorspace,
                        ),
                        iaa.EdgeDetect(alpha=(0, 0.3)),
                    ),
                ])
            ),
        ])

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader, reading images in ``cfg.INPUT.FORMAT``."""
        image_format = cfg.INPUT.FORMAT
        mapper = Mapper(
            cls._build_augmenter(image_format),
            is_train=True,
            image_format=image_format,
        )
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """Build an un-augmented test loader, reading images in ``cfg.INPUT.FORMAT``."""
        # iaa.Identity replaces the deprecated iaa.Noop (imgaug >= 0.4).
        mapper = Mapper(iaa.Identity(), is_train=False, image_format=cfg.INPUT.FORMAT)
        return build_detection_test_loader(cfg, dataset_name, mapper=mapper)
def train(config_file, output_dir, image_dir, extra_aug=True, seed=42):
    """Build an AugTrainer from a detectron2 config file and run training.

    Args:
        config_file: path of the detectron2 config to merge into the defaults.
        output_dir: value stored in ``cfg.OUTPUT_DIR``.
        image_dir: image root directory. NOTE(review): currently unused;
            datasets are presumably registered elsewhere. Confirm the intent.
        extra_aug: NOTE(review): currently unused.
        seed: seed for imgaug's global random state.

    Returns:
        The return value of ``trainer.train()``.
    """
    # Seed before the trainer (and thus the augmenters built inside its data
    # loader) is constructed, so no augmenter state predates the seed.
    # NOTE(review): DataLoader worker processes may all inherit this same
    # imgaug state; verify that augmentations differ across workers.
    ia_random.seed(seed)
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.OUTPUT_DIR = output_dir
    cfg.freeze()
    trainer = AugTrainer(cfg)
    return trainer.train()
if __name__ == '__main__':
    # Example: run the default Mask R-CNN training for COCO datasets while the
    # images are stored in /mnt/images:
    #   python train_with_imgaug.py \
    #       ./configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml /mnt/images
    parser = argparse.ArgumentParser(description='Train and evaluate based on config file.')
    parser.add_argument('config', help='Specify config file', metavar='CONFIG_FILE')
    parser.add_argument('images', help='Image root directory')
    parser.add_argument('-o', '--output', help='Output directory', default=None)
    # type=int: launch() multiplies and compares this value, so a string from
    # the command line would raise a TypeError.
    parser.add_argument('--gpu-per-machine', help='GPU number per machine', type=int, default=1)
    args = parser.parse_args()

    config = str(pathlib.Path(args.config).resolve().absolute())
    # Without -o, write to a timestamped directory such as ./output/20210929T0403.
    output = pathlib.Path(args.output or './output/' + datetime.now().strftime('%Y%m%dT%H%M'))
    output.mkdir(parents=True, exist_ok=True)
    output = str(output.resolve().absolute())
    images = str(pathlib.Path(args.images).resolve().absolute())

    setup_logger()
    launch(train, args.gpu_per_machine, dist_url='auto', args=(config, output, images))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.