Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Add options to shade param reconstructions #290

Merged
merged 2 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions specparam/plts/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from specparam.sim.gen import gen_freqs, gen_aperiodic
from specparam.core.modutils import safe_import, check_dependency
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs

Expand Down Expand Up @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs)
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
average='mean', shade='sem', plot_individual=True,
log_freqs=False, colors=None, labels=None,
ax=None, **plot_kwargs):
"""Plot reconstructions of model aperiodic fits.
Expand All @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp].
freq_range : list of [float, float]
The frequency range to plot the peak fits across, as [f_min, f_max].
average : {'mean', 'median'}, optional, default: 'mean'
Approach to take to average across components.
If set to None, no average is plotted.
shade : {'sem', 'std'}, optional, default: 'sem'
Approach for shading above/below the average reconstruction
If set to None, no yshade is plotted.
plot_individual : bool, optional, default: True
Whether to plot individual component reconstructions.
If False, only the average component reconstruction is plotted.
control_offset : boolean, optional, default: False
Whether to control for the offset, by setting it to zero.
log_freqs : boolean, optional, default: False
Expand Down Expand Up @@ -103,28 +114,28 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,

colors = colors[0] if isinstance(colors, list) else colors

avg_vals = np.zeros(shape=[len(freqs)])

for ap_params in aps:
all_ap_vals = np.zeros(shape=(len(aps), len(freqs)))
for ind, ap_params in enumerate(aps):

if control_offset:

# Copy the object to not overwrite any data
ap_params = ap_params.copy()
ap_params[0] = 0

# Recreate & plot the aperiodic component from parameters
# Create & collect the aperiodic component model from parameters
ap_vals = gen_aperiodic(freqs, ap_params)
all_ap_vals[ind, :] = ap_vals

ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)

# Collect a running average across components
avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0)
if plot_individual:
ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)

# Plot the average component
avg = avg_vals / aps.shape[0]
avg_color = 'black' if not colors else colors
ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels)
# Plot the average across all components
if average is not False:
avg_color = 'black' if not colors else colors
plot_yshade(freqs, all_ap_vals, average=average, shade=shade,
shade_alpha=plot_kwargs.pop('shade_alpha', 0.15),
color=avg_color, linewidth=3.75, label=labels, ax=ax)

# Add axis labels
ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency')
Expand Down
34 changes: 23 additions & 11 deletions specparam/plts/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from specparam.core.funcs import gaussian_function
from specparam.core.modutils import safe_import, check_dependency
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs

Expand Down Expand Up @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None,

@savefig
@style_plot
def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs):
def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True,
colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot reconstructions of model peak fits.

Parameters
Expand All @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **
freq_range : list of [float, float] , optional
The frequency range to plot the peak fits across, as [f_min, f_max].
If not provided, defaults to +/- 4 around given peak center frequencies.
average : {'mean', 'median'}, optional, default: 'mean'
Approach to take to average across components.
If set to None, no average is plotted.
shade : {'sem', 'std'}, optional, default: 'sem'
Approach for shading above/below the average reconstruction
If set to None, no yshade is plotted.
plot_individual : bool, optional, default: True
Whether to plot individual component reconstructions.
If False, only the average component reconstruction is plotted.
colors : str or list of str, optional
Color(s) to plot data.
labels : list of str, optional
Expand Down Expand Up @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **

colors = colors[0] if isinstance(colors, list) else colors

avg_vals = np.zeros(shape=[len(freqs)])
all_peak_vals = np.zeros(shape=(len(peaks), len(freqs)))
for ind, peak_params in enumerate(peaks):

for peak_params in peaks:

# Create & plot the peak model from parameters
# Create & collect the peak model from parameters
peak_vals = gaussian_function(freqs, *peak_params)
ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)
all_peak_vals[ind, :] = peak_vals

# Collect a running average average peaks
avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0)
if plot_individual:
ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)

# Plot the average across all components
avg = avg_vals / peaks.shape[0]
avg_color = 'black' if not colors else colors
ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels)
if average is not False:
avg_color = 'black' if not colors else colors
plot_yshade(freqs, all_peak_vals, average=average, shade=shade,
shade_alpha=plot_kwargs.pop('shade_alpha', 0.15),
color=avg_color, linewidth=3.75, label=labels, ax=ax)

# Add axis labels
ax.set_xlabel('Frequency')
Expand Down
27 changes: 16 additions & 11 deletions specparam/plts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,16 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non
Data values to be plotted on the y-axis. `shade` must be provided if 1d.
average : 'mean', 'median' or callable, optional, default: 'mean'
Averaging approach for plotting the average. Only used if y_vals is 2d.
If set to None, no average line is plotted.
shade : 'std', 'sem', 1d array or callable, optional, default: 'std'
Approach for shading above/below the average.
If set to None, no shading is plotted.
scale : float, optional, default: 1.
Factor to multiply the plotted shade by.
color : str, optional, default: None
Color to plot.
plot_function : callable, optional
xx
Function to use to create the plot.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Expand All @@ -168,18 +170,21 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non

shade_alpha = plot_kwargs.pop('shade_alpha', 0.25)

avg_data = compute_average(y_vals, average=average)
if plot_function:
plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs)
else:
ax.plot(x_vals, avg_data, color=color, **plot_kwargs)
avg_data = compute_average(y_vals, average=average if average else 'mean')

# Compute shade values and apply scaling
shade_vals = compute_dispersion(y_vals, shade) * scale
if average is not None:

# Plot +/- y-shading around spectrum
ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals,
alpha=shade_alpha, color=color)
if plot_function:
plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs)
else:
ax.plot(x_vals, avg_data, color=color, **plot_kwargs)

if shade is not None:

# Compute shade values, apply scaling & plot +/- y-shading
shade_vals = compute_dispersion(y_vals, shade) * scale
ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals,
alpha=shade_alpha, color=color)


@check_dependency(plt, 'matplotlib')
Expand Down