Skip to content

Commit

Permalink
add shade options to pe & ap params
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Jul 19, 2023
1 parent aba8941 commit 921a064
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 25 deletions.
35 changes: 23 additions & 12 deletions specparam/plts/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from specparam.sim.gen import gen_freqs, gen_aperiodic
from specparam.core.modutils import safe_import, check_dependency
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs

Expand Down Expand Up @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs)
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
average='mean', shade='sem', plot_individual=True,
log_freqs=False, colors=None, labels=None,
ax=None, **plot_kwargs):
"""Plot reconstructions of model aperiodic fits.
Expand All @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp].
freq_range : list of [float, float]
The frequency range to plot the peak fits across, as [f_min, f_max].
average : {'mean', 'median'}, optional, default: 'mean'
Approach to take to average across components.
If set to None, no average is plotted.
shade : {'sem', 'std'}, optional, default: 'sem'
Approach for shading above/below the average reconstruction
If set to None, no yshade is plotted.
plot_individual : bool, optional, default: True
Whether to plot individual component reconstructions.
If False, only the average component reconstruction is plotted.
control_offset : boolean, optional, default: False
Whether to control for the offset, by setting it to zero.
log_freqs : boolean, optional, default: False
Expand Down Expand Up @@ -103,28 +114,28 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,

colors = colors[0] if isinstance(colors, list) else colors

avg_vals = np.zeros(shape=[len(freqs)])

for ap_params in aps:
all_ap_vals = np.zeros(shape=(len(aps), len(freqs)))
for ind, ap_params in enumerate(aps):

if control_offset:

# Copy the object to not overwrite any data
ap_params = ap_params.copy()
ap_params[0] = 0

# Recreate & plot the aperiodic component from parameters
# Create & collect the aperiodic component model from parameters
ap_vals = gen_aperiodic(freqs, ap_params)
all_ap_vals[ind, :] = ap_vals

ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)

# Collect a running average across components
avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0)
if plot_individual:
ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)

# Plot the average component
avg = avg_vals / aps.shape[0]
avg_color = 'black' if not colors else colors
ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels)
# Plot the average across all components
if average is not False:
avg_color = 'black' if not colors else colors
plot_yshade(freqs, all_ap_vals, average=average, shade=shade,
shade_alpha=plot_kwargs.pop('shade_alpha', 0.15),
color=avg_color, linewidth=3.75, label=labels, ax=ax)

# Add axis labels
ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency')
Expand Down
34 changes: 23 additions & 11 deletions specparam/plts/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from specparam.core.funcs import gaussian_function
from specparam.core.modutils import safe_import, check_dependency
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs

Expand Down Expand Up @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None,

@savefig
@style_plot
def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs):
def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True,
colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot reconstructions of model peak fits.
Parameters
Expand All @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **
freq_range : list of [float, float] , optional
The frequency range to plot the peak fits across, as [f_min, f_max].
If not provided, defaults to +/- 4 around given peak center frequencies.
average : {'mean', 'median'}, optional, default: 'mean'
Approach to take to average across components.
If set to None, no average is plotted.
shade : {'sem', 'std'}, optional, default: 'sem'
Approach for shading above/below the average reconstruction
If set to None, no yshade is plotted.
plot_individual : bool, optional, default: True
Whether to plot individual component reconstructions.
If False, only the average component reconstruction is plotted.
colors : str or list of str, optional
Color(s) to plot data.
labels : list of str, optional
Expand Down Expand Up @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **

colors = colors[0] if isinstance(colors, list) else colors

avg_vals = np.zeros(shape=[len(freqs)])
all_peak_vals = np.zeros(shape=(len(peaks), len(freqs)))
for ind, peak_params in enumerate(peaks):

for peak_params in peaks:

# Create & plot the peak model from parameters
# Create & collect the peak model from parameters
peak_vals = gaussian_function(freqs, *peak_params)
ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)
all_peak_vals[ind, :] = peak_vals

# Collect a running average average peaks
avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0)
if plot_individual:
ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)

# Plot the average across all components
avg = avg_vals / peaks.shape[0]
avg_color = 'black' if not colors else colors
ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels)
if average is not False:
avg_color = 'black' if not colors else colors
plot_yshade(freqs, all_peak_vals, average=average, shade=shade,
shade_alpha=plot_kwargs.pop('shade_alpha', 0.15),
color=avg_color, linewidth=3.75, label=labels, ax=ax)

# Add axis labels
ax.set_xlabel('Frequency')
Expand Down
4 changes: 2 additions & 2 deletions specparam/plts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non

shade_alpha = plot_kwargs.pop('shade_alpha', 0.25)

if average is not None:
avg_data = compute_average(y_vals, average=average if average else 'mean')

avg_data = compute_average(y_vals, average=average)
if average is not None:

if plot_function:
plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs)
Expand Down

0 comments on commit 921a064

Please sign in to comment.