Train a person-specific patch Active Appearance Model (AAM) and use it to fit facial landmarks.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train a person specific patch AAM. \n",
"\n",
"For more effective model, images of public datasets \n",
"are additionally loaded before training, in order to ensure a greater variety (and thus \n",
"fitting power) in the model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Uses the Menpo project (menpo, menpofit, menpowidgets). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from os import listdir\n",
"from os.path import join, isdir\n",
"import menpo.io as mio \n",
"import sys \n",
"import numpy as np\n",
"from glob import glob\n",
"from menpofit.visualize import print_progress\n",
"from menpo.landmark import face_ibug_68_to_face_ibug_49 as f_49\n",
"import random\n",
"\n",
"# when run in a terminal the DISPLAY is not defined.\n",
"try: \n",
" %matplotlib inline\n",
"except:\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load images from public datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def load_images_dataset(p_d, crop=0.1, max_im=400, nr_p=49):\n",
" \"\"\"\n",
" Loads the images from a dataset given a path. In that path the landmark files are assumed to be as well.\n",
" Each landmark file should have the same name as the image, e.g. for image 0001.png -> 0001.pts.\n",
" :param p_d: (string) Path of the images/landmark files.\n",
" :param crop: (float, optional) Amount of cropping around landmarks.\n",
" :param max_im: (int, optional) Maximum number of images to be loaded\n",
" :param nr_p: (int, optional) Number of landmark points. Should be either 68 or 49.\n",
" :return: (list) All the images loaded\n",
" \"\"\"\n",
" assert(isdir(p_d))\n",
" images = []\n",
" for im in print_progress(list(mio.import_images(p_d, max_images=max_im))):\n",
" # convert to greyscale\n",
" if im.n_channels == 3:\n",
" im = im.as_greyscale()\n",
" # crop around the landmarks \n",
" im = im.crop_to_landmarks_proportion(crop)\n",
" if im.landmarks['PTS'].lms.n_points != nr_p:\n",
" # then only the method of conversion from 68 to 49 supported in this version\n",
" assert(nr_p == 49 and im.landmarks['PTS'].lms.n_points == 68)\n",
" im.landmarks['PTS'] = f_49(im.landmarks['PTS'])\n",
" # append the image loaded to the list of images\n",
" images.append(im)\n",
" return images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# load images from public datasets\n",
"p_bd = '/vol/atlas/databases/'\n",
"assert(isdir(p_bd))\n",
"\n",
"folds = ['ibug', 'afw', '300w', 'helen/trainset', 'helen/testset', 'lfpw/trainset', 'lfpw/testset']\n",
"folds = ['ibug', '300w', 'lfpw/testset']\n",
"db_images = []\n",
"for fold in folds:\n",
" print(fold)\n",
" im1 = load_images_dataset(join(p_bd, fold, ''), max_im=200)\n",
" assert(len(im1) > 5)\n",
" db_images += im1\n",
"\n",
"print('The total length of images loaded is {}.'.format(len(db_images)))"
]
},
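{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check: a minimal sketch of the 68 -> 49 conversion applied inside `load_images_dataset()`, assuming (as above) that the 'ibug' folder contains images with 68-point `PTS` annotations. It re-imports a single image and compares the number of points before and after `face_ibug_68_to_face_ibug_49`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# minimal sketch: re-import one annotated image and inspect the point count\n",
"# before and after the 68 -> 49 conversion used in load_images_dataset().\n",
"im_demo = list(mio.import_images(join(p_bd, 'ibug', ''), max_images=1))[0]\n",
"print('points before: {}'.format(im_demo.landmarks['PTS'].lms.n_points))\n",
"im_demo.landmarks['PTS'] = f_49(im_demo.landmarks['PTS'])\n",
"print('points after: {}'.format(im_demo.landmarks['PTS'].lms.n_points))"
]
},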
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# optionally visualise the images with the landmarks\n",
"from menpowidgets import visualize_images\n",
"visualize_images(db_images)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define the paths"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# path where the clips are.\n",
"# The structure assumed is: \n",
"# -> path_base \n",
"# |-> 'frames'\n",
"# |-> [name_of_clip]\n",
"# |-> [frame_name].[extension]\n",
"# |-> ...\n",
"# |-> 'init'\n",
"# |-> [name_of_clip]\n",
"# |-> [frame_name].pts\n",
" \n",
"p_b = '/vol/atlas/homes/grigoris/videos_external/jie_2_2016/'\n",
"assert(isdir(p_b))\n",
"# rename below to choose the clip that will be used\n",
"clip = listdir(join(p_b, 'frames', ''))[0]\n",
"print('The clip chosen is: \\'{}\\''.format(clip))\n",
"\n",
"# all the paths below should exist.\n",
"p_fr = join(p_b, 'frames', clip, '')\n",
"# path of the initial landmarks (p_ln), used for loading only \n",
"# the landmarks that are well-fit for training the person-specific model: \n",
"p_ln = join(p_b, 'init', clip, '')\n",
"# path of the initial tracked results (all that we wish to be re-fit):\n",
"p_init = join(p_b, 'init', clip, '')\n",
"# path that the newly fit landmarks will be exported:\n",
"p_exp = join(p_b, 'init2', clip, '')"
]
},
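{
"cell_type": "markdown",
"metadata": {},
"source": [
"The export path `p_exp` may not exist yet for a new clip. A minimal sketch, using only the standard library, that creates it before any landmarks are written there:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# create the export directory if it is missing; mio.export_landmark_file\n",
"# (used further below) expects the target directory to exist.\n",
"from os import makedirs\n",
"if not isdir(p_exp):\n",
"    makedirs(p_exp)\n",
"print('The re-fitted landmarks will be exported to: {}'.format(p_exp))"
]
},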
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load images from the clip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def load_images_clip(p_d, p_ln, crop=0.1, max_im=400, nr_p=49):\n",
" \"\"\"\n",
" Loads the images from a clip given the path of the frames and the landmarks. \n",
" Similar to load_images_dataset() above.\n",
" Each landmark file should have the same name as the image, e.g. for image 0001.png -> 0001.pts.\n",
" :param p_d: (string) Path of the images files.\n",
" :param p_ln: (string) Path of the landmark files.\n",
" :param crop: (float, optional) Amount of cropping around landmarks.\n",
" :param max_im: (int, optional) Maximum number of images to be loaded\n",
" :param nr_p: (int, optional) Number of landmark points. Should be either 68 or 49.\n",
" :return: (list) All the images loaded\n",
" \"\"\"\n",
" assert(isdir(p_d) and isdir(p_ln))\n",
" images = []\n",
" for ln in print_progress(list(mio.import_landmark_files(p_ln, max_landmarks=max_im))):\n",
" # search for the image in the p_fr path\n",
" ims = glob(p_fr + ln.path.stem + '*')\n",
" assert(len(ims) >= 1) # sanity check that the image exists\n",
" im = mio.import_image(ims[0])\n",
" # convert to greyscale\n",
" if im.n_channels == 3:\n",
" im = im.as_greyscale()\n",
" # attach the landmarks to the image\n",
" im.landmarks['PTS'] = ln\n",
" # crop around the landmarks \n",
" im = im.crop_to_landmarks_proportion(crop)\n",
" if im.landmarks['PTS'].lms.n_points != nr_p:\n",
" # then only the method of conversion from 68 to 49 supported in this version\n",
" assert(nr_p == 49 and im.landmarks['PTS'].lms.n_points == 68)\n",
" im.landmarks['PTS'] = f_49(im.landmarks['PTS'])\n",
" # append the image loaded to the list of images\n",
" images.append(im)\n",
" return images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# load the actual images from the clip\n",
"clip_images = load_images_clip(p_fr, p_ln, max_im=250)"
]
},
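{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check, mirroring the one for the public datasets above, that some annotated frames were actually found for the chosen clip:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# sanity check: at least some frames with initial landmarks should be loaded.\n",
"assert(len(clip_images) > 0)\n",
"print('Loaded {} annotated frames from clip \\'{}\\'.'.format(len(clip_images), clip))"
]
},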
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train the patch AAM and define the fitter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# create a new list with all the images\n",
"training_images = db_images + clip_images\n",
"random.Random(9).shuffle(training_images)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# you can choose the features you want, but fast_dsift perform quite well and are quite fast.\n",
"from menpo.feature import fast_dsift\n",
"# imports for GN-DPM builder/fitter:\n",
"from menpofit.aam import PatchAAM, LucasKanadeAAMFitter\n",
"\n",
"## OPTIONS\n",
"features = fast_dsift\n",
"patch_shape = (18, 18)\n",
"crop = 0.2\n",
"diagonal = 180\n",
"# if the initial fitting is quite precise, then you might consider increasing \n",
"# the params below, so that it gets projected in a higher-dimensional space.\n",
"n_shape=[4, 10]\n",
"n_appearance=[60, 150]\n",
"fitter = []\n",
"\n",
"# careful with the RAM, it takes quite some memory to train.\n",
"aam = PatchAAM(training_images, verbose=True, holistic_features=features, patch_shape=patch_shape,\n",
" diagonal=diagonal, scales=(.5, 1))\n",
"\n",
"del training_images # delete the training images!\n",
"sampling_step = 2\n",
"sampling_mask = np.zeros(patch_shape, dtype=np.bool) # create the sampling mask\n",
"sampling_mask[::sampling_step, ::sampling_step] = True\n",
"fitter = LucasKanadeAAMFitter(aam, n_shape=n_shape, n_appearance=n_appearance, \n",
" sampling=sampling_mask)"
]
},
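{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sampling mask built above keeps only every other pixel of each 18x18 patch, which speeds up fitting at a small cost in accuracy. The short sketch below simply inspects the mask that was constructed (no new assumptions):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# inspect the sampling mask: with a step of 2, roughly a quarter of the\n",
"# patch pixels are used during fitting.\n",
"print('patch pixels: {}, sampled pixels: {}'.format(sampling_mask.size,\n",
"                                                    int(sampling_mask.sum())))\n",
"print(sampling_mask.astype(int)[:6, :6])"
]
},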
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# visualise the trained AAM\n",
"try: # in case it's in a terminal\n",
" %matplotlib inline \n",
" from menpowidgets import visualize_patch_aam\n",
" visualize_patch_aam(aam, [4, 10], [60, 140])\n",
"except:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# # you can safely delete the aam variable if memory is an issue.\n",
"# delete aam \n",
"# del aam"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fit the trained AAM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def fit_frame(frame_name, p_fr, p_ln, p_exp, crop=0.1):\n",
" \"\"\"\n",
" Fits the model learnt in a frame with fit from shape. \n",
" Assumes a global var of fitter that is the model.\n",
" :param frame_name: (string) Name of the frame, e.g. '0001.png'.\n",
" :param p_fr: (string) Name of the landmarks path.\n",
" :param p_ln: (string) Name of the init landmarks path.\n",
" :param p_exp: (string) Name that the landmarks will be exported into.\n",
" :param crop: (float, optional) Amount of cropping around landmarks.\n",
" :return: \n",
" \"\"\"\n",
" # get rid of the extension of the image, e.g. '001.png' -> name = '001'.\n",
" name = frame_name[:frame_name.rfind('.')]\n",
" p0 = p_ln + name + '*.pts'\n",
" res = glob(p0)\n",
" if len(res) == 0: \n",
" return\n",
" # load the image and the initial landmarks\n",
" im = mio.import_image(p_fr + frame_name)\n",
" ln = mio.import_landmark_file(res[0])\n",
" # attach the landmarks to the image\n",
" im.landmarks['PTS'] = ln \n",
" if im.n_channels == 3:\n",
" im = im.as_greyscale()\n",
" # fit the image, the max_iter might need to be modified for \n",
" # faster fitting depending on the case of the clip.\n",
" fr = fitter.fit_from_shape(im, im.landmarks['PTS'].lms, crop_image=crop,\n",
" max_iter=14)\n",
" # export the resulting landmark\n",
" p_wr = p_exp + im.path.stem + '.pts'\n",
" mio.export_landmark_file(fr.fitted_image.landmarks['final'], p_wr, overwrite=True)\n",
" return fr.fitted_image"
]
},
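{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before fitting the whole clip, it can be useful to try `fit_frame` on a single frame and inspect the result. A minimal sketch, assuming the frames folder is non-empty (the frame is skipped if it has no initial landmarks):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# quick check (sketch): fit a single frame before running the whole clip.\n",
"first_frame = sorted(listdir(p_fr))[0]\n",
"fitted = fit_frame(first_frame, p_fr, p_init, p_exp)\n",
"if fitted is not None:\n",
"    print('Fitted frame {} with {} landmark points.'.format(\n",
"        first_frame, fitted.landmarks['final'].lms.n_points))"
]
},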
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# get the sorted list of frames\n",
"l_frames = sorted(listdir(p_fr))\n",
"\n",
"# fit all the frames iteratively\n",
"for frame_name in print_progress(l_frames):\n",
" c = fit_frame(frame_name, p_fr, p_init, p_exp)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import the new landmarks and visualise the results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images = []\n",
"for i in mio.import_images(p_fr, max_images=100): \n",
" try:\n",
" pts = mio.import_landmark_file(p_init + i.path.stem + '.pts')\n",
" pts_fin = mio.import_landmark_file(p_exp + i.path.stem + '.pts')\n",
" except ValueError:\n",
" continue\n",
" i.landmarks['init'] = pts\n",
" i.landmarks['final'] = pts_fin\n",
" i = i.crop_to_landmarks_proportion(0.2, 'init')\n",
" images.append(i)\n",
"assert(len(images) > 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"try:\n",
" %matplotlib inline\n",
" visualize_images(images)\n",
"except:\n",
" pass"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}