Train a person-specific patch Active Appearance Model (AAM) and use it to fit facial landmarks.
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Train a person specific patch AAM. \n", | |
"\n", | |
"For more effective model, images of public datasets \n", | |
"are additionally loaded before training, in order to ensure a greater variety (and thus \n", | |
"fitting power) in the model." | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Uses the Menpo project (menpo, menpofit, menpowidgets)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from os import listdir\n",
    "from os.path import join, isdir\n",
    "import menpo.io as mio\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from menpofit.visualize import print_progress\n",
    "from menpo.landmark import face_ibug_68_to_face_ibug_49 as f_49\n",
    "import random\n",
    "\n",
    "# when run from a terminal the DISPLAY is not defined, so the magic can fail.\n",
    "try:\n",
    "    %matplotlib inline\n",
    "except:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load images from public datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
"def load_images_dataset(p_d, crop=0.1, max_im=400, nr_p=49):\n", | |
" \"\"\"\n", | |
" Loads the images from a dataset given a path. In that path the landmark files are assumed to be as well.\n", | |
" Each landmark file should have the same name as the image, e.g. for image 0001.png -> 0001.pts.\n", | |
" :param p_d: (string) Path of the images/landmark files.\n", | |
" :param crop: (float, optional) Amount of cropping around landmarks.\n", | |
" :param max_im: (int, optional) Maximum number of images to be loaded\n", | |
" :param nr_p: (int, optional) Number of landmark points. Should be either 68 or 49.\n", | |
" :return: (list) All the images loaded\n", | |
" \"\"\"\n", | |
" assert(isdir(p_d))\n", | |
" images = []\n", | |
" for im in print_progress(list(mio.import_images(p_d, max_images=max_im))):\n", | |
" # convert to greyscale\n", | |
" if im.n_channels == 3:\n", | |
" im = im.as_greyscale()\n", | |
" # crop around the landmarks \n", | |
" im = im.crop_to_landmarks_proportion(crop)\n", | |
" if im.landmarks['PTS'].lms.n_points != nr_p:\n", | |
" # then only the method of conversion from 68 to 49 supported in this version\n", | |
" assert(nr_p == 49 and im.landmarks['PTS'].lms.n_points == 68)\n", | |
" im.landmarks['PTS'] = f_49(im.landmarks['PTS'])\n", | |
" # append the image loaded to the list of images\n", | |
" images.append(im)\n", | |
" return images" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
"# load images from public datasets\n", | |
"p_bd = '/vol/atlas/databases/'\n", | |
"assert(isdir(p_bd))\n", | |
"\n", | |
"folds = ['ibug', 'afw', '300w', 'helen/trainset', 'helen/testset', 'lfpw/trainset', 'lfpw/testset']\n", | |
"folds = ['ibug', '300w', 'lfpw/testset']\n", | |
"db_images = []\n", | |
"for fold in folds:\n", | |
" print(fold)\n", | |
" im1 = load_images_dataset(join(p_bd, fold, ''), max_im=200)\n", | |
" assert(len(im1) > 5)\n", | |
" db_images += im1\n", | |
"\n", | |
"print('The total length of images loaded is {}.'.format(len(db_images)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# optionally visualise the images with the landmarks\n", | |
"from menpowidgets import visualize_images\n", | |
"visualize_images(db_images)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Define the paths" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# path where the clips are.\n", | |
"# The structure assumed is: \n", | |
"# -> path_base \n", | |
"# |-> 'frames'\n", | |
"# |-> [name_of_clip]\n", | |
"# |-> [frame_name].[extension]\n", | |
"# |-> ...\n", | |
"# |-> 'init'\n", | |
"# |-> [name_of_clip]\n", | |
"# |-> [frame_name].pts\n", | |
" \n", | |
"p_b = '/vol/atlas/homes/grigoris/videos_external/jie_2_2016/'\n", | |
"assert(isdir(p_b))\n", | |
"# rename below to choose the clip that will be used\n", | |
"clip = listdir(join(p_b, 'frames', ''))[0]\n", | |
"print('The clip chosen is: \\'{}\\''.format(clip))\n", | |
"\n", | |
"# all the paths below should exist.\n", | |
"p_fr = join(p_b, 'frames', clip, '')\n", | |
"# path of the initial landmarks (p_ln), used for loading only \n", | |
"# the landmarks that are well-fit for training the person-specific model: \n", | |
"p_ln = join(p_b, 'init', clip, '')\n", | |
"# path of the initial tracked results (all that we wish to be re-fit):\n", | |
"p_init = join(p_b, 'init', clip, '')\n", | |
"# path that the newly fit landmarks will be exported:\n", | |
"p_exp = join(p_b, 'init2', clip, '')" | |
] | |
}, | |
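  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The export path `p_exp` may not exist yet, while the landmark export further below assumes that it does. A minimal sketch to create it if missing (this cell is an added convenience, not part of the original workflow):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from os import makedirs\n",
    "\n",
    "# create the export directory if it does not exist yet, so that\n",
    "# mio.export_landmark_file() does not fail later on.\n",
    "if not isdir(p_exp):\n",
    "    makedirs(p_exp)\n",
    "print('The re-fitted landmarks will be exported to: {}'.format(p_exp))"
   ]
  },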
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load images from the clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
"def load_images_clip(p_d, p_ln, crop=0.1, max_im=400, nr_p=49):\n", | |
" \"\"\"\n", | |
" Loads the images from a clip given the path of the frames and the landmarks. \n", | |
" Similar to load_images_dataset() above.\n", | |
" Each landmark file should have the same name as the image, e.g. for image 0001.png -> 0001.pts.\n", | |
" :param p_d: (string) Path of the images files.\n", | |
" :param p_ln: (string) Path of the landmark files.\n", | |
" :param crop: (float, optional) Amount of cropping around landmarks.\n", | |
" :param max_im: (int, optional) Maximum number of images to be loaded\n", | |
" :param nr_p: (int, optional) Number of landmark points. Should be either 68 or 49.\n", | |
" :return: (list) All the images loaded\n", | |
" \"\"\"\n", | |
" assert(isdir(p_d) and isdir(p_ln))\n", | |
" images = []\n", | |
" for ln in print_progress(list(mio.import_landmark_files(p_ln, max_landmarks=max_im))):\n", | |
" # search for the image in the p_fr path\n", | |
" ims = glob(p_fr + ln.path.stem + '*')\n", | |
" assert(len(ims) >= 1) # sanity check that the image exists\n", | |
" im = mio.import_image(ims[0])\n", | |
" # convert to greyscale\n", | |
" if im.n_channels == 3:\n", | |
" im = im.as_greyscale()\n", | |
" # attach the landmarks to the image\n", | |
" im.landmarks['PTS'] = ln\n", | |
" # crop around the landmarks \n", | |
" im = im.crop_to_landmarks_proportion(crop)\n", | |
" if im.landmarks['PTS'].lms.n_points != nr_p:\n", | |
" # then only the method of conversion from 68 to 49 supported in this version\n", | |
" assert(nr_p == 49 and im.landmarks['PTS'].lms.n_points == 68)\n", | |
" im.landmarks['PTS'] = f_49(im.landmarks['PTS'])\n", | |
" # append the image loaded to the list of images\n", | |
" images.append(im)\n", | |
" return images" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# load the actual images from the clip\n",
    "clip_images = load_images_clip(p_fr, p_ln, max_im=250)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the patch AAM and define the fitter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# create a new list with all the images\n",
    "training_images = db_images + clip_images\n",
    "random.Random(9).shuffle(training_images)"
   ]
  },
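  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (an added convenience, not part of the original notebook): the balance between generic dataset images and person-specific frames determines how much of the model's capacity is spent on this particular person."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# report the composition of the combined training set\n",
    "print('{} dataset images + {} clip frames = {} training images.'.format(\n",
    "    len(db_images), len(clip_images), len(training_images)))"
   ]
  },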
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# you can choose the features you want, but fast_dsift performs quite well and is quite fast.\n",
    "from menpo.feature import fast_dsift\n",
    "# imports for GN-DPM builder/fitter:\n",
    "from menpofit.aam import PatchAAM, LucasKanadeAAMFitter\n",
    "\n",
    "## OPTIONS\n",
    "features = fast_dsift\n",
    "patch_shape = (18, 18)\n",
    "crop = 0.2\n",
    "diagonal = 180\n",
    "# if the initial fitting is quite precise, then you might consider increasing\n",
    "# the params below, so that the face is projected onto a higher-dimensional subspace.\n",
    "n_shape = [4, 10]\n",
    "n_appearance = [60, 150]\n",
    "\n",
    "# careful with the RAM, training takes quite some memory.\n",
    "aam = PatchAAM(training_images, verbose=True, holistic_features=features, patch_shape=patch_shape,\n",
    "               diagonal=diagonal, scales=(.5, 1))\n",
    "\n",
    "del training_images  # delete the training images!\n",
    "sampling_step = 2\n",
    "sampling_mask = np.zeros(patch_shape, dtype=bool)  # create the sampling mask\n",
    "sampling_mask[::sampling_step, ::sampling_step] = True\n",
    "fitter = LucasKanadeAAMFitter(aam, n_shape=n_shape, n_appearance=n_appearance,\n",
    "                              sampling=sampling_mask)"
   ]
  },
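  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the sub-sampling does, the mask can be displayed directly. This visualisation is an added sketch (assuming matplotlib is importable), not part of the original training code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "try:  # in case it's in a terminal\n",
    "    import matplotlib.pyplot as plt\n",
    "    # the True entries are the patch pixels actually sampled during fitting\n",
    "    plt.imshow(sampling_mask, cmap='gray', interpolation='none')\n",
    "    plt.title('Sampling mask ({}x{} patch, step {})'.format(\n",
    "        patch_shape[0], patch_shape[1], sampling_step))\n",
    "except:\n",
    "    pass"
   ]
  },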
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# visualise the trained AAM\n",
    "try:  # in case it's in a terminal\n",
    "    %matplotlib inline\n",
    "    from menpowidgets import visualize_patch_aam\n",
    "    visualize_patch_aam(aam, n_shape, n_appearance)\n",
    "except:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# you can safely delete the aam variable if memory is an issue:\n",
    "# del aam"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fit the trained AAM"
   ]
  },
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def fit_frame(frame_name, p_fr, p_ln, p_exp, crop=0.1):\n", | |
" \"\"\"\n", | |
" Fits the model learnt in a frame with fit from shape. \n", | |
" Assumes a global var of fitter that is the model.\n", | |
" :param frame_name: (string) Name of the frame, e.g. '0001.png'.\n", | |
" :param p_fr: (string) Name of the landmarks path.\n", | |
" :param p_ln: (string) Name of the init landmarks path.\n", | |
" :param p_exp: (string) Name that the landmarks will be exported into.\n", | |
" :param crop: (float, optional) Amount of cropping around landmarks.\n", | |
" :return: \n", | |
" \"\"\"\n", | |
" # get rid of the extension of the image, e.g. '001.png' -> name = '001'.\n", | |
" name = frame_name[:frame_name.rfind('.')]\n", | |
" p0 = p_ln + name + '*.pts'\n", | |
" res = glob(p0)\n", | |
" if len(res) == 0: \n", | |
" return\n", | |
" # load the image and the initial landmarks\n", | |
" im = mio.import_image(p_fr + frame_name)\n", | |
" ln = mio.import_landmark_file(res[0])\n", | |
" # attach the landmarks to the image\n", | |
" im.landmarks['PTS'] = ln \n", | |
" if im.n_channels == 3:\n", | |
" im = im.as_greyscale()\n", | |
" # fit the image, the max_iter might need to be modified for \n", | |
" # faster fitting depending on the case of the clip.\n", | |
" fr = fitter.fit_from_shape(im, im.landmarks['PTS'].lms, crop_image=crop,\n", | |
" max_iter=14)\n", | |
" # export the resulting landmark\n", | |
" p_wr = p_exp + im.path.stem + '.pts'\n", | |
" mio.export_landmark_file(fr.fitted_image.landmarks['final'], p_wr, overwrite=True)\n", | |
" return fr.fitted_image" | |
] | |
}, | |
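  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before re-fitting the whole clip, it can be worth fitting a single frame as a smoke test. The cell below is an added example of calling fit_frame(); it assumes the first frame of the clip has initial landmarks (otherwise fit_frame() returns None)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# smoke test: fit only the first frame and view the result.\n",
    "fitted_im = fit_frame(sorted(listdir(p_fr))[0], p_fr, p_init, p_exp)\n",
    "if fitted_im is not None:\n",
    "    try:  # in case it's in a terminal\n",
    "        fitted_im.view_landmarks(group='final')\n",
    "    except:\n",
    "        pass"
   ]
  },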
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# get the sorted list of frames\n",
    "l_frames = sorted(listdir(p_fr))\n",
    "\n",
    "# fit all the frames iteratively\n",
    "for frame_name in print_progress(l_frames):\n",
    "    fit_frame(frame_name, p_fr, p_init, p_exp)"
   ]
  },
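  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A simple check on the outcome (an added convenience): compare the number of exported landmark files with the number of frames. Frames without initial landmarks are skipped by fit_frame(), so a difference is possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# count the newly exported landmark files versus the frames of the clip\n",
    "n_exported = len(glob(p_exp + '*.pts'))\n",
    "print('Exported {} landmark files for {} frames.'.format(n_exported, len(l_frames)))"
   ]
  },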
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import the new landmarks and visualise the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
"images = []\n", | |
"for i in mio.import_images(p_fr, max_images=100): \n", | |
" try:\n", | |
" pts = mio.import_landmark_file(p_init + i.path.stem + '.pts')\n", | |
" pts_fin = mio.import_landmark_file(p_exp + i.path.stem + '.pts')\n", | |
" except ValueError:\n", | |
" continue\n", | |
" i.landmarks['init'] = pts\n", | |
" i.landmarks['final'] = pts_fin\n", | |
" i = i.crop_to_landmarks_proportion(0.2, 'init')\n", | |
" images.append(i)\n", | |
"assert(len(images) > 0)" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    %matplotlib inline\n",
    "    visualize_images(images)\n",
    "except:\n",
    "    pass"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}