Skip to content

Commit

Permalink
Merge pull request #287 from fooof-tools/grep
Browse files Browse the repository at this point in the history
[MNT] - Plot & report tweaks
  • Loading branch information
TomDonoghue authored Jul 21, 2023
2 parents f8ed771 + 70c20e8 commit 72acddc
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 28 deletions.
8 changes: 4 additions & 4 deletions fooof/core/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def save_report_fg(fg, file_name, file_path=None, add_settings=True):

# Initialize figure
_ = plt.figure(figsize=REPORT_FIGSIZE)
grid = gridspec.GridSpec(n_rows, 2, wspace=0.4, hspace=0.25, height_ratios=height_ratios)
grid = gridspec.GridSpec(n_rows, 2, wspace=0.35, hspace=0.25, height_ratios=height_ratios)

# First / top: text results
ax0 = plt.subplot(grid[0, :])
Expand All @@ -108,15 +108,15 @@ def save_report_fg(fg, file_name, file_path=None, add_settings=True):

# Aperiodic parameters plot
ax1 = plt.subplot(grid[1, 0])
plot_fg_ap(fg, ax1)
plot_fg_ap(fg, ax1, custom_styler=None)

# Goodness of fit plot
ax2 = plt.subplot(grid[1, 1])
plot_fg_gf(fg, ax2)
plot_fg_gf(fg, ax2, custom_styler=None)

# Peak center frequencies plot
ax3 = plt.subplot(grid[2, :])
plot_fg_peak_cens(fg, ax3)
plot_fg_peak_cens(fg, ax3, custom_styler=None)

# Third - Model settings
if add_settings:
Expand Down
8 changes: 4 additions & 4 deletions fooof/plts/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **plot_kwargs):
raise NoModelError("No model fit results are available, can not proceed.")

fig = plt.figure(figsize=plot_kwargs.pop('figsize', PLT_FIGSIZES['group']))
gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2])
gs = gridspec.GridSpec(2, 2, wspace=0.35, hspace=0.35, height_ratios=[1, 1.2])

# Apply scatter kwargs to all subplots
scatter_kwargs = plot_kwargs
scatter_kwargs['all_axes'] = True

# Aperiodic parameters plot
ax0 = plt.subplot(gs[0, 0])
plot_fg_ap(fg, ax0, **scatter_kwargs)
plot_fg_ap(fg, ax0, **scatter_kwargs, custom_styler=None)

# Goodness of fit plot
ax1 = plt.subplot(gs[0, 1])
plot_fg_gf(fg, ax1, **scatter_kwargs)
plot_fg_gf(fg, ax1, **scatter_kwargs, custom_styler=None)

# Center frequencies plot
ax2 = plt.subplot(gs[1, :])
plot_fg_peak_cens(fg, ax2, **plot_kwargs)
plot_fg_peak_cens(fg, ax2, **plot_kwargs, custom_styler=None)


@savefig
Expand Down
10 changes: 5 additions & 5 deletions fooof/plts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
###################################################################################################

# Define default figure sizes
PLT_FIGSIZES = {'spectral' : (10, 8),
PLT_FIGSIZES = {'spectral' : (8.5, 6.5),
'params' : (7, 6),
'group' : (12, 10)}
'group' : (9, 7)}

# Define defaults for colors for plots, based on what is plotted
PLT_COLORS = {'data' : 'black',
Expand Down Expand Up @@ -45,8 +45,8 @@

## Define default values for plot aesthetics
# These are all custom style arguments
TITLE_FONTSIZE = 20
LABEL_SIZE = 16
TICK_LABELSIZE = 16
TITLE_FONTSIZE = 18
LABEL_SIZE = 14
TICK_LABELSIZE = 12
LEGEND_SIZE = 12
LEGEND_LOC = 'best'
17 changes: 12 additions & 5 deletions fooof/plts/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,14 @@ def apply_custom_style(ax, **kwargs):
ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)},
loc=kwargs.pop('legend_loc', LEGEND_LOC))

plt.tight_layout()
# Apply tight layout to the figure object, if matplotlib is new enough
# If available, `.set_layout_engine` should be equivalent to
# `plt.tight_layout()`, but seems to raise fewer warnings...
try:
fig = plt.gcf()
fig.set_layout_engine('tight')
except:
plt.tight_layout()


def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
Expand All @@ -192,10 +199,10 @@ def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
Each of these sub-functions can be replaced by passing in replacement callables.
"""

axis_styler(ax, **kwargs)
line_styler(ax, **kwargs)
collection_styler(ax, **kwargs)
custom_styler(ax, **kwargs)
axis_styler(ax, **kwargs) if axis_styler is not None else None
line_styler(ax, **kwargs) if line_styler is not None else None
collection_styler(ax, **kwargs) if collection_styler is not None else None
custom_styler(ax, **kwargs) if custom_styler is not None else None


def style_plot(func, *args, **kwargs):
Expand Down
21 changes: 11 additions & 10 deletions fooof/plts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.utils import check_ax, set_alpha
from fooof.plts.settings import TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE

plt = safe_import('.pyplot', 'matplotlib')

Expand Down Expand Up @@ -46,14 +47,14 @@ def plot_scatter_1(data, label=None, title=None, x_val=0, ax=None):
ax.scatter(x_data, data, s=36, alpha=set_alpha(len(data)))

if label:
ax.set_ylabel(label, fontsize=16)
ax.set_ylabel(label, fontsize=LABEL_SIZE)
ax.set(xticks=[x_val], xticklabels=[label])

if title:
ax.set_title(title, fontsize=20)
ax.set_title(title, fontsize=TITLE_FONTSIZE)

ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=12)
ax.tick_params(axis='x', labelsize=TICK_LABELSIZE)
ax.tick_params(axis='y', labelsize=TICK_LABELSIZE)

ax.set_xlim([-0.5, 0.5])

Expand Down Expand Up @@ -89,12 +90,12 @@ def plot_scatter_2(data_0, label_0, data_1, label_1, title=None, ax=None):
plot_scatter_1(data_1, label_1, x_val=1, ax=ax1)

if title:
ax.set_title(title, fontsize=20)
ax.set_title(title, fontsize=TITLE_FONTSIZE)

ax.set(xlim=[-0.5, 1.5],
xticks=[0, 1],
xticklabels=[label_0, label_1])
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='x', labelsize=TICK_LABELSIZE)


@check_dependency(plt, 'matplotlib')
Expand All @@ -121,13 +122,13 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None):

ax.hist(data[~np.isnan(data)], n_bins, range=x_lims, alpha=0.8)

ax.set_xlabel(label, fontsize=16)
ax.set_ylabel('Count', fontsize=16)
ax.set_xlabel(label, fontsize=LABEL_SIZE)
ax.set_ylabel('Count', fontsize=LABEL_SIZE)

if x_lims:
ax.set_xlim(x_lims)

if title:
ax.set_title(title, fontsize=20)
ax.set_title(title, fontsize=TITLE_FONTSIZE)

ax.tick_params(axis='both', labelsize=12)
ax.tick_params(axis='both', labelsize=TICK_LABELSIZE)

0 comments on commit 72acddc

Please sign in to comment.