From 921a06493f83858bf0c3468879295a6b2c8d8b81 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 18 Jul 2023 22:11:13 -0400 Subject: [PATCH] add shade options to pe & ap params --- specparam/plts/aperiodic.py | 35 +++++++++++++++++++++++------------ specparam/plts/periodic.py | 34 +++++++++++++++++++++++----------- specparam/plts/templates.py | 4 ++-- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/specparam/plts/aperiodic.py b/specparam/plts/aperiodic.py index a32167b5..a57cd02a 100644 --- a/specparam/plts/aperiodic.py +++ b/specparam/plts/aperiodic.py @@ -8,6 +8,7 @@ from specparam.sim.gen import gen_freqs, gen_aperiodic from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs) @style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, + average='mean', shade='sem', plot_individual=True, log_freqs=False, colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model aperiodic fits. @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp]. freq_range : list of [float, float] The frequency range to plot the peak fits across, as [f_min, f_max]. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. control_offset : boolean, optional, default: False Whether to control for the offset, by setting it to zero. log_freqs : boolean, optional, default: False @@ -103,9 +114,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) - - for ap_params in aps: + all_ap_vals = np.zeros(shape=(len(aps), len(freqs))) + for ind, ap_params in enumerate(aps): if control_offset: @@ -113,18 +123,19 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, ap_params = ap_params.copy() ap_params[0] = 0 - # Recreate & plot the aperiodic component from parameters + # Create & collect the aperiodic component model from parameters ap_vals = gen_aperiodic(freqs, ap_params) + all_ap_vals[ind, :] = ap_vals - ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - - # Collect a running average across components - avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0) + if plot_individual: + ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - # Plot the average component - avg = avg_vals / aps.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels) + # Plot the average across all components + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_ap_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency') diff --git a/specparam/plts/periodic.py b/specparam/plts/periodic.py index c2c40f30..c69ba7e3 100644 --- a/specparam/plts/periodic.py +++ b/specparam/plts/periodic.py @@ -8,6 +8,7 @@ from specparam.core.funcs import gaussian_function from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, @savefig @style_plot -def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): +def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True, + colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model peak fits. Parameters @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** freq_range : list of [float, float] , optional The frequency range to plot the peak fits across, as [f_min, f_max]. If not provided, defaults to +/- 4 around given peak center frequencies. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. colors : str or list of str, optional Color(s) to plot data. labels : list of str, optional @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) + all_peak_vals = np.zeros(shape=(len(peaks), len(freqs))) + for ind, peak_params in enumerate(peaks): - for peak_params in peaks: - - # Create & plot the peak model from parameters + # Create & collect the peak model from parameters peak_vals = gaussian_function(freqs, *peak_params) - ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) + all_peak_vals[ind, :] = peak_vals - # Collect a running average average peaks - avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0) + if plot_individual: + ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) # Plot the average across all components - avg = avg_vals / peaks.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels) + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_peak_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('Frequency') diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 8ece82b0..12f946ea 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -170,9 +170,9 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) - if average is not None: + avg_data = compute_average(y_vals, average=average if average else 'mean') - avg_data = compute_average(y_vals, average=average) + if average is not None: if plot_function: plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs)