Skip to content

Commit

Permalink
Merge pull request #290 from fooof-tools/paramshade
Browse files Browse the repository at this point in the history
[ENH] - Add options to shade param reconstructions
  • Loading branch information
TomDonoghue authored Jul 21, 2023
2 parents 14fbde0 + 921a064 commit d7e7165
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 34 deletions.
35 changes: 23 additions & 12 deletions specparam/plts/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from specparam.sim.gen import gen_freqs, gen_aperiodic
from specparam.core.modutils import safe_import, check_dependency
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs

Expand Down Expand Up @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs)
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
average='mean', shade='sem', plot_individual=True,
log_freqs=False, colors=None, labels=None,
ax=None, **plot_kwargs):
"""Plot reconstructions of model aperiodic fits.
Expand All @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp].
freq_range : list of [float, float]
The frequency range to plot the peak fits across, as [f_min, f_max].
average : {'mean', 'median'}, optional, default: 'mean'
Approach to take to average across components.
If set to None, no average is plotted.
shade : {'sem', 'std'}, optional, default: 'sem'
Approach for shading above/below the average reconstruction
If set to None, no yshade is plotted.
plot_individual : bool, optional, default: True
Whether to plot individual component reconstructions.
If False, only the average component reconstruction is plotted.
control_offset : boolean, optional, default: False
Whether to control for the offset, by setting it to zero.
log_freqs : boolean, optional, default: False
Expand Down Expand Up @@ -103,28 +114,28 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,

colors = colors[0] if isinstance(colors, list) else colors

avg_vals = np.zeros(shape=[len(freqs)])

for ap_params in aps:
all_ap_vals = np.zeros(shape=(len(aps), len(freqs)))
for ind, ap_params in enumerate(aps):

if control_offset:

# Copy the object to not overwrite any data
ap_params = ap_params.copy()
ap_params[0] = 0

# Recreate & plot the aperiodic component from parameters
# Create & collect the aperiodic component model from parameters
ap_vals = gen_aperiodic(freqs, ap_params)
all_ap_vals[ind, :] = ap_vals

ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)

# Collect a running average across components
avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0)
if plot_individual:
ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)

# Plot the average component
avg = avg_vals / aps.shape[0]
avg_color = 'black' if not colors else colors
ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels)
# Plot the average across all components
if average is not False:
avg_color = 'black' if not colors else colors
plot_yshade(freqs, all_ap_vals, average=average, shade=shade,
shade_alpha=plot_kwargs.pop('shade_alpha', 0.15),
color=avg_color, linewidth=3.75, label=labels, ax=ax)

# Add axis labels
ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency')
Expand Down
34 changes: 23 additions & 11 deletions specparam/plts/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from specparam.core.funcs import gaussian_function
from specparam.core.modutils import safe_import, check_dependency
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs

Expand Down Expand Up @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None,

@savefig
@style_plot
def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs):
def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True,
colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot reconstructions of model peak fits.
Parameters
Expand All @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **
freq_range : list of [float, float] , optional
The frequency range to plot the peak fits across, as [f_min, f_max].
If not provided, defaults to +/- 4 around given peak center frequencies.
average : {'mean', 'median'}, optional, default: 'mean'
Approach to take to average across components.
If set to None, no average is plotted.
shade : {'sem', 'std'}, optional, default: 'sem'
Approach for shading above/below the average reconstruction
If set to None, no yshade is plotted.
plot_individual : bool, optional, default: True
Whether to plot individual component reconstructions.
If False, only the average component reconstruction is plotted.
colors : str or list of str, optional
Color(s) to plot data.
labels : list of str, optional
Expand Down Expand Up @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **

colors = colors[0] if isinstance(colors, list) else colors

avg_vals = np.zeros(shape=[len(freqs)])
all_peak_vals = np.zeros(shape=(len(peaks), len(freqs)))
for ind, peak_params in enumerate(peaks):

for peak_params in peaks:

# Create & plot the peak model from parameters
# Create & collect the peak model from parameters
peak_vals = gaussian_function(freqs, *peak_params)
ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)
all_peak_vals[ind, :] = peak_vals

# Collect a running average average peaks
avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0)
if plot_individual:
ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)

# Plot the average across all components
avg = avg_vals / peaks.shape[0]
avg_color = 'black' if not colors else colors
ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels)
if average is not False:
avg_color = 'black' if not colors else colors
plot_yshade(freqs, all_peak_vals, average=average, shade=shade,
shade_alpha=plot_kwargs.pop('shade_alpha', 0.15),
color=avg_color, linewidth=3.75, label=labels, ax=ax)

# Add axis labels
ax.set_xlabel('Frequency')
Expand Down
27 changes: 16 additions & 11 deletions specparam/plts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,16 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non
Data values to be plotted on the y-axis. `shade` must be provided if 1d.
average : 'mean', 'median' or callable, optional, default: 'mean'
Averaging approach for plotting the average. Only used if y_vals is 2d.
If set to None, no average line is plotted.
shade : 'std', 'sem', 1d array or callable, optional, default: 'std'
Approach for shading above/below the average.
If set to None, no shading is plotted.
scale : float, optional, default: 1.
Factor to multiply the plotted shade by.
color : str, optional, default: None
Color to plot.
plot_function : callable, optional
xx
Function to use to create the plot.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Expand All @@ -168,18 +170,21 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non

shade_alpha = plot_kwargs.pop('shade_alpha', 0.25)

avg_data = compute_average(y_vals, average=average)
if plot_function:
plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs)
else:
ax.plot(x_vals, avg_data, color=color, **plot_kwargs)
avg_data = compute_average(y_vals, average=average if average else 'mean')

# Compute shade values and apply scaling
shade_vals = compute_dispersion(y_vals, shade) * scale
if average is not None:

# Plot +/- y-shading around spectrum
ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals,
alpha=shade_alpha, color=color)
if plot_function:
plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs)
else:
ax.plot(x_vals, avg_data, color=color, **plot_kwargs)

if shade is not None:

# Compute shade values, apply scaling & plot +/- y-shading
shade_vals = compute_dispersion(y_vals, shade) * scale
ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals,
alpha=shade_alpha, color=color)


@check_dependency(plt, 'matplotlib')
Expand Down

0 comments on commit d7e7165

Please sign in to comment.