Last active
March 31, 2023 10:08
-
-
Save dgobbi/bddbbabb1a9c86e8d8373752859dac9f to your computer and use it in GitHub Desktop.
Generate many NIfTI files, each corresponding to a different transformation of a single NIfTI file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Program vtk_augment_nifti.py | |
Sep 25, 2018 | |
David Gobbi | |
[email protected] | |
""" | |
import vtk | |
import sys | |
import os.path | |
import argparse | |
import math | |
brief = "Apply multiple transformations to a NIfTI file."

# Shown as the argparse epilog; main() uses RawDescriptionHelpFormatter, so
# the line breaks below are kept as written.
helptext = """
This program will take in an image as input and will generate multiple
output files, each with a different transformation applied.

You must supply an existing output directory for the files to be
written to.  Each output file will be named "<file>_NNNN<ext>"
where <file> is the name of the input without its extension, "NNNN"
is four digits, and <ext> is the extension of the input (.nii or
.nii.gz; inputs with any other extension produce .nii.gz files).

You can specify the number of distinct transformations as
follows:
  -R N (will produce 3*NxNxN rotations: N angles about each of NxN
        axes, for each of 3 faces of a cube used to pick the axes;
        for N >= 2 some axes lie on edges shared by two faces and
        are therefore repeated)
  -S M (will produce M transforms, since only 1 scale param is used)
  -T L (will produce LxL transforms, since 2 translation params are used)
The total number of outputs is therefore 3*N*N*N*M*L*L.

Translations are only done in the row and column directions, not the slice
direction.  The largest applied translation will be 15% of the image size.
With L=1 (the default), the row and column translations are zero.

The largest rotation that will be applied is 45 degrees.  There will not
be any outputs that correspond to a rotation of zero degrees, unless you
specify N=0 (which applies no rotation at all).

The scale factor will be a minimum of 1.0 and a maximum of 1.5; with M=1,
the single scale factor used is 1.25, the midpoint of that range.  No
scales less than 1.0 are used since shrinking an image can lead to
aliasing artifacts.

Note that there are two ways to specify input images: -i and -l.
For "-i", the image is assumed to be greyscale and linear interpolation
is used.  For "-l", the image is assumed to be binary or indexed, and
nearest-neighbor interpolation is used.  Both may be given at once; if
both images share the same grid (dimensions, spacing and origin), output
NNNN of one corresponds to output NNNN of the other.
"""
def linscale(scale_range, n):
    """Return ``n`` evenly spaced values spanning ``scale_range``.

    The first value equals ``scale_range[0]`` and the last equals
    ``scale_range[1]``.  When ``n`` is 1 or less, a one-element list holding
    the midpoint of the range is returned instead.
    """
    lo, hi = scale_range[0], scale_range[1]
    if n > 1:
        steps = float(n - 1)
        values = []
        for k in range(n):
            # weighted blend of the two endpoints, moving from lo to hi
            values.append((n - k - 1)/steps*lo + k/steps*hi)
        return values
    return [0.5*(lo + hi)]
def combine(*scales):
    """Return the Cartesian product of the given sequences as lists.

    Each result is a new list with one element taken from each input
    sequence, in argument order.  The first sequence varies fastest.  With
    no arguments, or if any sequence is empty, the result is empty.  The
    grid is built by repeated list concatenation, which is adequate for the
    small grids used here.
    """
    if not scales:
        return []
    grid = [[value] for value in scales[0]]
    for axis in scales[1:]:
        # outer loop over the new axis keeps earlier axes varying fastest
        grid = [point + [value] for value in axis for point in grid]
    return grid
def generate_rotations(angle_range, n):
    """Return rotations as [angle, x, y, z] lists with a unit-length axis.

    Rotation axes are spread approximately evenly by sampling an n x n grid
    on three faces of the cube [-1, 1]^3 (the faces x == 1, y == 1, z == 1),
    rather than sampling a sphere directly.  The n angles come from
    ``angle_range`` (same units), skipping its lower end, so no zero-angle
    rotation is produced when n >= 1.  Hence 3*n**3 rotations in total.

    If n <= 0, the single identity rotation [0.0, 1.0, 0.0, 0.0] is returned.

    NOTE: for n >= 2 the edge samples include -1.0 and 1.0, so axes on the
    edges shared by two faces (and (1, 1, 1), shared by all three) occur
    more than once in the result.
    """
    if n <= 0:
        return [[0.0, 1.0, 0.0, 0.0]]

    # n + 1 angles across the range, dropping the first (lowest) one
    angles = linscale(angle_range, n + 1)[1:]
    # n sample positions along one edge of the cube
    edge = linscale([-1.0, 1.0], n)

    faces = (
        combine(angles, [1.0], edge, edge),  # axes on the x == 1 face
        combine(angles, edge, [1.0], edge),  # axes on the y == 1 face
        combine(angles, edge, edge, [1.0]),  # axes on the z == 1 face
    )
    rotations = []
    for face in faces:
        rotations.extend(face)

    # Scale each axis to unit length.  Every axis has one component equal
    # to 1.0, so the length is never zero.
    for rot in rotations:
        length = math.sqrt(rot[1]**2 + rot[2]**2 + rot[3]**2)
        rot[1:] = [component / length for component in rot[1:]]
    return rotations
def generate_scales(scale_range, n):
    """Return ``n`` isotropic scale factors evenly spaced over ``scale_range``.

    This delegates to ``linscale``, so for ``n <= 1`` the only factor is the
    midpoint of the range.
    """
    factors = linscale(scale_range, n)
    return factors
def generate_translations(trans_ranges, n):
    """Return the grid of translation vectors to apply.

    ``trans_ranges`` holds one [min, max] range per axis, in x, y, z order.
    For each given axis, ``n`` evenly spaced offsets are taken across its
    range (just the midpoint when ``n <= 1``).  Axes missing from
    ``trans_ranges`` get a single offset of 0.0, so no translation happens
    along them.  In this script the caller passes only the in-plane x and y
    ranges, which keeps the slice axis fixed.

    Each result is a list [tx, ty, tz], or longer if more than three ranges
    are given.  The first axis varies fastest.
    """
    offsets = [linscale(rng, n) for rng in trans_ranges]
    # Padding must be the identity offset 0.0: any other value would shift
    # every output along an axis that is meant to stay untouched.
    while len(offsets) < 3:
        offsets.append([0.0])
    return combine(*offsets)
def build_transform(rotation, scale, translation, center):
    """Create the vtkTransform for one set of augmentation parameters.

    rotation: [angle, x, y, z], forwarded unchanged to RotateWXYZ (VTK
        takes the angle in degrees and (x, y, z) as the rotation axis).
    scale: isotropic scale factor.
    translation: (tx, ty, tz) offset, applied last.
    center: point about which the rotation and scaling are performed.
    """
    xform = vtk.vtkTransform()
    # In PostMultiply mode each operation is applied after the previous
    # ones, so the calls below read in the order they act on a point.
    xform.PostMultiply()
    neg_center = [-coord for coord in center]
    xform.Translate(neg_center)      # bring the center to the origin
    xform.RotateWXYZ(*rotation)
    xform.Scale(scale, scale, scale)
    xform.Translate(center)          # return the center to its place
    xform.Translate(translation)     # finally shift the whole image
    return xform
def process_one_output(output_file, header, sform, qform, image, transform,
                       is_label):
    """Resample ``image`` through ``transform`` and write it as NIfTI.

    output_file: path of the file to write.
    header, sform, qform: NIfTI metadata copied verbatim into the output.
    image: the vtkImageData to resample.
    transform: the vtkTransform to apply to the image.
    is_label: when false, linear interpolation is used.  When true, no
        interpolation mode is set, so vtkImageReslice's default
        (nearest neighbour) applies and label values are never blended.

    The output keeps the sampling grid of the input, since input sampling
    transformation is switched off.  A slower, higher-quality alternative
    would be a vtkImageSincInterpolator with a Blackman window.
    """
    resampler = vtk.vtkImageReslice()
    resampler.SetNumberOfThreads(1)
    resampler.SetInputData(image)
    if not is_label:
        resampler.SetInterpolationModeToLinear()
    resampler.TransformInputSamplingOff()
    # vtkImageReslice maps output points into the input through its
    # reslice transform, so the inverse is what moves the image forward.
    resampler.SetResliceTransform(transform.GetInverse())
    resampler.Update()

    nifti_out = vtk.vtkNIFTIImageWriter()
    nifti_out.SetInputData(resampler.GetOutput())
    nifti_out.SetFileName(output_file)
    nifti_out.SetNIFTIHeader(header)
    nifti_out.SetSFormMatrix(sform)
    nifti_out.SetQFormMatrix(qform)
    nifti_out.Write()
def _split_nifti_name(input_file):
    """Split the basename of ``input_file`` into (stem, extension).

    '.nii.gz' is checked before '.nii'.  A name with neither extension keeps
    its whole basename as the stem and is paired with '.nii.gz'.
    """
    name = os.path.basename(input_file)
    for suffix in ('.nii.gz', '.nii'):
        if name.endswith(suffix):
            return name[:-len(suffix)], suffix
    return name, '.nii.gz'


def process_one_input(input_file, args, is_label):
    """Read one NIfTI file and write every transformed copy of it.

    input_file: path of the NIfTI image to augment.
    args: parsed command line.  Reads rotations, scales, translations,
        output and silent.
    is_label: passed through to process_one_output to select the
        interpolation used when resampling.

    Outputs go to ``args.output`` as "<stem>_NNNN<ext>", numbered from 0001
    in the order of the parameter grid.  Unless ``args.silent`` is set, one
    progress dot is printed per output, 50 per line.
    """
    reader = vtk.vtkNIFTIImageReader()
    reader.SetFileName(input_file)
    reader.Update()

    # metadata reused unchanged for every output, plus the image itself
    header = reader.GetNIFTIHeader()
    sform = reader.GetSFormMatrix()
    qform = reader.GetQFormMatrix()
    image = reader.GetOutput()
    center = image.GetCenter()

    # Maximum translation is this fraction of the image size along each
    # axis.  The last (slice) axis is dropped so that only x/y ranges are
    # produced; downstream algorithms are assumed to work slice-by-slice.
    fraction = 0.15
    extents = [dim*spc for dim, spc in
               zip(image.GetDimensions(), image.GetSpacing())]
    trans_ranges = [[-size*fraction, size*fraction] for size in extents[:-1]]
    rotation_range = [0.0, 45.0]
    scale_range = [1.0, 1.5]

    # every combination of rotation, scale and translation
    params = combine(generate_rotations(rotation_range, args.rotations),
                     generate_scales(scale_range, args.scales),
                     generate_translations(trans_ranges, args.translations))

    verbose = not args.silent
    if verbose:
        sys.stdout.write("Generating %d outputs per input " % len(params))
        sys.stdout.write("(each \".\" is one output).\n")
        sys.stdout.flush()

    base, ext = _split_nifti_name(input_file)

    counter = 0
    for counter, (rotation, scale, translation) in enumerate(params, 1):
        if verbose:
            sys.stdout.write(".")
            sys.stdout.flush()
        transform = build_transform(rotation, scale, translation, center)
        output_file = os.path.join(args.output,
                                   '%s_%04d%s' % (base, counter, ext))
        process_one_output(output_file, header, sform, qform, image,
                           transform, is_label)
        if verbose and counter % 50 == 0:
            sys.stdout.write("\n")  # wrap the progress dots

    if verbose:
        # finish a partially filled line of progress dots
        if counter % 50 != 0:
            sys.stdout.write("\n")
        sys.stdout.flush()
def main(argv):
    """Parse the command line and augment the requested input images.

    argv: the full argument vector, program name included (as in sys.argv).

    Usage errors are reported through argparse, which prints the usage and
    exits with status 2.  This covers giving neither -i nor -l and naming
    an output directory that does not exist.
    """
    parser = argparse.ArgumentParser(
        prog=argv[0],
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=brief, epilog=helptext)
    parser.add_argument('-i', '--input', required=False,
                        help="Input greyscale image.")
    parser.add_argument('-l', '--label', required=False,
                        help="Input label image.")
    parser.add_argument('-o', '--output', required=True,
                        help="Output directory.")
    parser.add_argument('-R', '--rotations', type=int, default=3,
                        help="Steps per rotation degree of freedom.")
    parser.add_argument('-S', '--scales', type=int, default=2,
                        help="Steps per scale degree of freedom.")
    parser.add_argument('-T', '--translations', type=int, default=1,
                        help="Steps per translation degree of freedom.")
    parser.add_argument('-s', '--silent', action='count',
                        help="Do not print progress information.")
    args = parser.parse_args(argv[1:])

    # Fail loudly instead of exiting successfully with nothing done.
    if not args.input and not args.label:
        parser.error("no input given: use -i/--input and/or -l/--label")
    # The output directory must already exist; check it once up front
    # rather than attempting every write into a missing directory.
    if not os.path.isdir(args.output):
        parser.error("output directory does not exist: %s" % args.output)

    # Clamp step counts to their smallest meaningful values.  Rotations may
    # be 0, which means "no rotation".
    args.rotations = max(0, args.rotations)
    args.scales = max(1, args.scales)
    args.translations = max(1, args.translations)

    # augment all the input files
    if args.input:
        process_one_input(args.input, args, is_label=False)
    if args.label:
        process_one_input(args.label, args, is_label=True)
if __name__ == "__main__":
    # Script entry point: main() expects the complete argv, including the
    # program name, which it uses as the argparse prog.
    main(sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment