import os
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
# Attempt to import imageio for video generation
# imageio is an optional dependency: record whether it could be imported so
# that the GIF/MP4 helpers can fail with a clear message instead of crashing.
try:
    import imageio

    IMAGEIO_AVAILABLE = True
except ImportError:
    IMAGEIO_AVAILABLE = False
# Alias kept so that code referring to the historical, misspelled flag name
# keeps working; both names are always defined and always agree.
IMEGEIO_AVAILABLE = IMAGEIO_AVAILABLE
from cycler import cycler
# Ten-color palette given as hex RGB strings; the values match Matplotlib's
# default ("tab10") color cycle. Used as a line color order by set_plot_options.
COLORS_PYTHON = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]
# Seven-color palette given as hex RGB strings; the values match MATLAB's
# default line color order. Used as a line color order by set_plot_options.
COLORS_MATLAB = [
    "#0072BD",
    "#D95319",
    "#EDB120",
    "#7E2F8E",
    "#77AC30",
    "#4DBEEE",
    "#A2142F",
]
def set_plot_options(
    fontsize=14,
    grid=True,
    major_ticks=True,
    minor_ticks=True,
    margin=0.05,
    color_order="matlab",
    linewidth=1.25,
):
    """
    Set options for creating publication-quality figures using Matplotlib.

    This function updates the global Matplotlib settings (``mpl.rcParams``) to better align with
    standards for publication-quality figures. Features include improved font selections,
    tick marks, grid appearance, and color selections.

    Parameters
    ----------
    fontsize : int, optional
        Base font size for text elements in the plot. Tick labels use ``fontsize - 1`` and
        legend text uses ``fontsize - 2``. Default is 14.
    grid : bool, optional
        Whether to show grid lines on the plot. Default is True.
    major_ticks : bool, optional
        Whether to show major ticks (on all four sides of the axes). Default is True.
    minor_ticks : bool, optional
        Whether to show minor ticks. Default is True.
    margin : float, optional
        Margin size for the x, y and z axes. Default is 0.05.
    color_order : str or sequence of colors, optional
        Color order to be used for plot lines. The names "python" (alias "default") and
        "matlab" select the built-in palettes (case-insensitive); any other value is passed
        unchanged to ``cycler(color=...)``. Default is "matlab".
    linewidth : float, optional
        Width of plot lines and of marker edges. Default is 1.25.
    """
    if isinstance(color_order, str):
        # "python" is the documented name; "default" is kept for backward compatibility.
        named_palettes = {
            "python": COLORS_PYTHON,
            "default": COLORS_PYTHON,
            "matlab": COLORS_MATLAB,
        }
        # Unrecognised strings are passed through untouched, as before.
        color_order = named_palettes.get(color_order.lower(), color_order)

    # Define dictionary of custom settings
    rcParams = {
        "text.usetex": False,
        "font.size": fontsize,
        "font.style": "normal",
        "font.family": "serif",  # 'serif', 'sans-serif', 'cursive', 'fantasy', 'monospace'
        "font.serif": ["Times New Roman"],  # ['times new roman', 'cmr10']
        "mathtext.fontset": "stix",  # ["stix", 'cm']
        "axes.edgecolor": "black",
        "axes.linewidth": 1.25,
        "axes.titlesize": fontsize,
        "axes.titleweight": "normal",
        "axes.titlepad": fontsize * 1.4,
        "axes.labelsize": fontsize,
        "axes.labelweight": "normal",
        "axes.labelpad": fontsize,
        "axes.xmargin": margin,
        "axes.ymargin": margin,
        "axes.zmargin": margin,
        "axes.grid": grid,
        "axes.grid.axis": "both",
        "axes.grid.which": "major",
        "axes.prop_cycle": cycler(color=color_order),
        "grid.alpha": 1,
        "grid.color": "#808080",  # Grey equivalent to 50% transparency
        "grid.linestyle": "-",
        "grid.linewidth": 0.5,
        "legend.borderaxespad": 1,
        "legend.borderpad": 0.6,
        "legend.edgecolor": "black",
        "legend.facecolor": "white",
        "legend.labelcolor": "black",
        "legend.labelspacing": 0.3,
        "legend.fancybox": True,
        "legend.fontsize": fontsize - 2,
        "legend.framealpha": 1.00,
        "legend.handleheight": 0.7,
        "legend.handlelength": 1.25,
        "legend.handletextpad": 0.8,
        "legend.markerscale": 1.0,
        "legend.numpoints": 1,
        "lines.linewidth": linewidth,
        "lines.markersize": 4.5,
        "lines.markeredgewidth": linewidth,
        "lines.markerfacecolor": "white",
        "xtick.direction": "in",
        "xtick.labelsize": fontsize - 1,
        "xtick.bottom": major_ticks,
        "xtick.top": major_ticks,
        "xtick.major.size": 6,
        "xtick.major.width": 1.25,
        "xtick.minor.size": 3,
        "xtick.minor.width": 0.75,
        "xtick.minor.visible": minor_ticks,
        "ytick.direction": "in",
        "ytick.labelsize": fontsize - 1,
        "ytick.left": major_ticks,
        "ytick.right": major_ticks,
        "ytick.major.size": 6,
        "ytick.major.width": 1.25,
        "ytick.minor.size": 3,
        "ytick.minor.width": 0.75,
        "ytick.minor.visible": minor_ticks,
        "savefig.dpi": 500,
    }

    # Update the internal Matplotlib settings dictionary
    mpl.rcParams.update(rcParams)
def print_installed_fonts():
    """
    Print the TrueType fonts that Matplotlib can locate on the system.

    The entries reported by Matplotlib's font manager are printed one per line,
    in sorted order.
    """
    font_entries = sorted(
        mpl.font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
    )
    for font_entry in font_entries:
        print(font_entry)
def print_rc_parameters(filename=None):
    """
    Print the active Matplotlib rcParams and optionally save them to a file.

    Each setting is printed as a ``key: value`` line.

    Parameters
    ----------
    filename : str, optional
        If given (and non-empty), the same ``key: value`` lines are also written to this
        file, overwriting any existing content. Default is None (print only).
    """
    for name, setting in mpl.rcParams.items():
        print(f"{name}: {setting}")

    # Nothing more to do when no output file was requested
    if not filename:
        return

    with open(filename, "w") as handle:
        handle.writelines(
            f"{name}: {setting}\n" for name, setting in mpl.rcParams.items()
        )
def create_gif(image_folder, output_file, duration=0.5):
    """
    Create a GIF from a series of images.

    All files in `image_folder` whose name ends with ".png" are read in sorted
    filename order and written as the frames of the GIF.

    Parameters
    ----------
    image_folder : str
        The path to the folder containing the images.
    output_file : str
        The path and filename of the output GIF.
    duration : float, optional
        Duration of each frame, forwarded to ``imageio.mimsave``. Default is 0.5.
        NOTE: the unit is interpreted by the installed imageio version.

    Raises
    ------
    ImportError
        If the optional dependency imageio is not installed.
    """
    # Import lazily instead of consulting a module-level flag, so a missing
    # dependency always surfaces as the intended ImportError.
    try:
        import imageio
    except ImportError as exc:
        raise ImportError(
            "imageio is not installed. Please run `pip install imageio-ffmpeg` to run this function."
        ) from exc

    # Sorted order makes the frame sequence deterministic
    frames = [
        imageio.imread(os.path.join(image_folder, name))
        for name in sorted(os.listdir(image_folder))
        if name.endswith(".png")
    ]
    imageio.mimsave(output_file, frames, duration=duration)
def create_mp4(image_folder, output_file, fps=10):
    """
    Create an MP4 video from a series of images.

    All files in `image_folder` whose name ends with ".png" are appended to the
    video in sorted filename order.

    Parameters
    ----------
    image_folder : str
        The path to the folder containing the images.
    output_file : str
        The path and filename of the output MP4 video.
    fps : int, optional
        Frames per second in the output video, by default 10.

    Raises
    ------
    ImportError
        If the optional dependency imageio is not installed.
    """
    # Import lazily instead of consulting a module-level flag, so a missing
    # dependency always surfaces as the intended ImportError.
    try:
        import imageio
    except ImportError as exc:
        raise ImportError(
            "imageio is not installed. Please run `pip install imageio-ffmpeg` to run this function."
        ) from exc

    # List the frames before opening the writer so that a bad folder path
    # fails before any output file is created.
    frame_names = [
        name for name in sorted(os.listdir(image_folder)) if name.endswith(".png")
    ]
    with imageio.get_writer(output_file, fps=fps) as writer:
        for name in frame_names:
            writer.append_data(imageio.imread(os.path.join(image_folder, name)))
def scale_graphics_x(fig, scale, mode="multiply"):
    """
    Scale or shift the x-coordinates of the graphics objects in a figure.

    Every axes of `fig` is visited and the x-coordinates of its lines, patches and
    collection paths are transformed in place. Axis limits are not recomputed.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose artists are modified.
    scale : float
        Factor (``mode="multiply"``) or offset (``mode="add"``) applied to the x-coordinates.
    mode : str, optional
        Either "multiply" or "add". Default is "multiply".

    Raises
    ------
    ValueError
        If `mode` is neither "multiply" nor "add".
    """
    if mode == "multiply":

        def transform(values):
            return values * scale

    elif mode == "add":

        def transform(values):
            return values + scale

    else:
        raise ValueError(
            f'Invalid mode "{mode}". Valid options are "multiply" and "add".'
        )

    for ax in fig.get_axes():
        # Scaling lines
        for line in ax.get_lines():
            # The stored data may be a plain list (which `*` would repeat instead
            # of scale); asanyarray also preserves masked arrays.
            line.set_xdata(transform(np.asanyarray(line.get_xdata())))

        # Scaling patches (like rectangles)
        for patch in ax.patches:
            if hasattr(patch, "set_x") and hasattr(patch, "set_width"):
                # Rectangle-like patches expose an anchor point and a width
                # rather than an (N, 2) vertex array.
                patch.set_x(transform(patch.get_x()))
                if mode == "multiply":
                    patch.set_width(patch.get_width() * scale)
            else:
                # Polygon-like patches: work on a float copy of the vertices
                xy = np.array(patch.get_xy(), dtype=float)
                xy[:, 0] = transform(xy[:, 0])
                patch.set_xy(xy)

        # Scaling contour plots
        # NOTE(review): collections positioned through offsets (e.g. scatter plots)
        # keep their marker shape in the paths - confirm they should be scaled too.
        for collection in ax.collections:
            for path in collection.get_paths():
                path.vertices[:, 0] = transform(path.vertices[:, 0])
            # Vertices were edited in place, so request a redraw explicitly
            collection.stale = True
def scale_graphics_y(fig, scale, mode="multiply"):
    """
    Scale or shift the y-coordinates of the graphics objects in a figure.

    Every axes of `fig` is visited and the y-coordinates of its lines, patches and
    collection paths are transformed in place. Axis limits are not recomputed.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose artists are modified.
    scale : float
        Factor (``mode="multiply"``) or offset (``mode="add"``) applied to the y-coordinates.
    mode : str, optional
        Either "multiply" or "add". Default is "multiply".

    Raises
    ------
    ValueError
        If `mode` is neither "multiply" nor "add".
    """
    if mode == "multiply":

        def transform(values):
            return values * scale

    elif mode == "add":

        def transform(values):
            return values + scale

    else:
        raise ValueError(
            f'Invalid mode "{mode}". Valid options are "multiply" and "add".'
        )

    for ax in fig.get_axes():
        # Scaling lines
        for line in ax.get_lines():
            # The stored data may be a plain list (which `*` would repeat instead
            # of scale); asanyarray also preserves masked arrays.
            line.set_ydata(transform(np.asanyarray(line.get_ydata())))

        # Scaling patches (like rectangles)
        for patch in ax.patches:
            if hasattr(patch, "set_y") and hasattr(patch, "set_height"):
                # Rectangle-like patches expose an anchor point and a height
                # rather than an (N, 2) vertex array.
                patch.set_y(transform(patch.get_y()))
                if mode == "multiply":
                    patch.set_height(patch.get_height() * scale)
            else:
                # Polygon-like patches: work on a float copy of the vertices
                xy = np.array(patch.get_xy(), dtype=float)
                xy[:, 1] = transform(xy[:, 1])
                patch.set_xy(xy)

        # Scaling contour plots
        # NOTE(review): collections positioned through offsets (e.g. scatter plots)
        # keep their marker shape in the paths - confirm they should be scaled too.
        for collection in ax.collections:
            for path in collection.get_paths():
                path.vertices[:, 1] = transform(path.vertices[:, 1])
            # Vertices were edited in place, so request a redraw explicitly
            collection.stale = True