Last active
February 24, 2025 19:15
-
-
Save hugke729/78655b82b885cde79e270f1c30da0b5f to your computer and use it in GitHub Desktop.
Simplify export of matplotlib figures when both raster and vector components are desired in output
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A function to rasterize components of a matplotlib figure while keeping
# axes, labels, etc. as vector components
# https://brushingupscience.wordpress.com/2017/05/09/vector-and-raster-in-one-with-matplotlib/
from inspect import getmembers, isclass

import matplotlib
import matplotlib.collections
import matplotlib.contour
import matplotlib.patches
import matplotlib.pyplot as plt
import numpy as np
def rasterize_and_save(fname, rasterize_list=None, fig=None, dpi=None,
                       savefig_kw=None):
    """Save a figure with raster and vector components

    This function lets you specify which objects to rasterize at the export
    stage, rather than within each plotting call. Rasterizing certain
    components of a complex figure can significantly reduce file size.

    Inputs
    ------
    fname : str
        Output filename with extension
    rasterize_list : list (or object)
        List of objects to rasterize (or a single object to rasterize).
        Each entry may be an artist (e.g. the result of ``scatter`` or
        ``pcolormesh``), a contour set (the result of ``contour``,
        ``contourf`` or ``tricontourf``), or a list/tuple of artists
        (e.g. the patches returned by ``bar``).
    fig : matplotlib figure object
        Defaults to current figure
    dpi : int
        Resolution (dots per inch) for rasterizing
    savefig_kw : dict
        Extra keywords to pass to ``fig.savefig``. The dict is copied, so
        the caller's dict is never modified.

    If rasterize_list is not specified, then all contour, pcolor, and
    collection objects (e.g., ``scatter, fill_between`` etc) found among the
    children of the figure's axes will be rasterized; the selected objects
    are printed.

    Side effect: each rasterized artist gets zorder -6 and its axes get
    rasterization zorder -5, so it is drawn below artists with zorder >= -5.

    Note: does not work correctly with round=True in Basemap

    Example
    -------
    Rasterize the contour, pcolor, and scatter plots, but not the line

    >>> import matplotlib.pyplot as plt
    >>> from numpy.random import random
    >>> X, Y, Z = random((9, 9)), random((9, 9)), random((9, 9))
    >>> fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
    >>> cax1 = ax1.contourf(Z)
    >>> cax2 = ax2.scatter(X, Y, s=Z)
    >>> cax3 = ax3.pcolormesh(Z)
    >>> cax4 = ax4.plot(Z[:, 0])
    >>> rasterize_list = [cax1, cax2, cax3]
    >>> rasterize_and_save('out.svg', rasterize_list, fig=fig, dpi=300)
    """
    # Behave like pyplot and act on current figure if no figure is specified
    fig = plt.gcf() if fig is None else fig

    # Work on a private copy: adding 'dpi' below must not leak into the
    # caller's dict (or, formerly, into a shared mutable default).
    savefig_kw = {} if savefig_kw is None else dict(savefig_kw)

    # Need to set_rasterization_zorder in order for rasterizing to work
    zorder = -5  # Somewhat arbitrary, just ensuring less than 0

    def _rasterize(artist):
        """Rasterize one artist and move it below the rasterization zorder."""
        # Figure-level artists have no axes; they can still be rasterized.
        if artist.axes is not None:
            artist.axes.set_rasterization_zorder(zorder)
        artist.set_zorder(zorder - 1)
        artist.set_rasterized(True)

    if rasterize_list is None:
        # Have a guess at stuff that should be rasterised: contour sets and
        # collections (QuadMesh, PathCollection, PolyCollection, ...).
        # Match on type rather than on str(item): the string test also hit
        # e.g. a Text title or a Line2D label containing "Contour".
        types_to_raster = (matplotlib.collections.Collection,
                           matplotlib.contour.ContourSet)
        print("""
No rasterize_list specified, so the following objects will
be rasterized: """)
        # Get all axes, and then get objects within axes
        rasterize_list = [item
                          for ax in fig.get_axes()
                          for item in ax.get_children()
                          if isinstance(item, types_to_raster)]
        print('\n'.join(str(x) for x in rasterize_list))
    elif not isinstance(rasterize_list, list):
        # Allow rasterize_list to be input as an object to rasterize
        rasterize_list = [rasterize_list]

    for item in rasterize_list:
        if isinstance(item, matplotlib.contour.ContourSet):
            # ContourSet is the common base of QuadContourSet and
            # TriContourSet. Fall back to .ax (used by older code) in case
            # .axes is unavailable on this matplotlib version.
            curr_ax = getattr(item, 'axes', None) or item.ax
            curr_ax.set_rasterization_zorder(zorder)
            if isinstance(item, matplotlib.collections.Collection):
                # matplotlib >= 3.8: the ContourSet is itself a single
                # Collection. Its deprecated `.collections` (removed in 3.10)
                # must not be touched: accessing it hides the ContourSet.
                parts = [item]
            else:
                # Older matplotlib: one collection per contour level, each
                # of which must be set individually.
                parts = item.collections
            for part in parts:
                part.set_zorder(zorder - 1)
                part.set_rasterized(True)
        elif isinstance(item, (list, tuple)):
            # A sequence of artists, e.g. patches from `bar` (a tuple
            # subclass) or lines from `plot`; an empty one is a no-op.
            for artist in item:
                _rasterize(artist)
        else:
            # For all other objects, we can just do it all at once
            _rasterize(item)

    # dpi is a savefig keyword argument, but treat it as special since it is
    # important to this function
    if dpi is not None:
        savefig_kw['dpi'] = dpi

    # Save resulting figure
    fig.savefig(fname, **savefig_kw)
# Test rasterize_and_save
if __name__ == '__main__':
    # Demo: rasterize the contour, pcolor, and scatter plots while keeping
    # the line plot as vector graphics. Compare with `==`, not `is`:
    # identity comparison with a str literal is not guaranteed to hold and
    # is a SyntaxWarning since Python 3.8.
    from numpy.random import random
    X, Y, Z = random((9, 9)), random((9, 9)), random((9, 9))
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
    cax1 = ax1.contourf(Z)
    cax2 = ax2.scatter(X, Y, s=Z)
    cax3 = ax3.pcolormesh(Z)
    cax4 = ax4.plot(Z[:, 0])  # deliberately left out: stays vector
    rasterize_list = [cax1, cax2, cax3]
    rasterize_and_save('out.svg', rasterize_list, fig=fig, dpi=300)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment