diff --git a/doc/api.rst b/doc/api.rst index 23f06e5e..1338100d 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -52,6 +52,7 @@ Functions to manipulate, examine and analyze FOOOF objects, and related utilitie compare_info average_fg + average_reconstructions combine_fooofs .. currentmodule:: fooof diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py index fe1b740c..a9ac7934 100644 --- a/fooof/objs/utils.py +++ b/fooof/objs/utils.py @@ -65,18 +65,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True): If there are no model fit results available to average across. """ - if avg_method not in ['mean', 'median']: - raise ValueError("Requested average method not understood.") if not fg.has_model: raise NoModelError("No model fit results are available, can not proceed.") - if avg_method == 'mean': - avg_func = np.nanmean - elif avg_method == 'median': - avg_func = np.nanmedian + avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian} + if avg_method not in avg_funcs.keys(): + raise ValueError("Requested average method not understood.") # Aperiodic parameters: extract & average - ap_params = avg_func(fg.get_params('aperiodic_params'), 0) + ap_params = avg_funcs[avg_method](fg.get_params('aperiodic_params'), 0) # Periodic parameters: extract & average peak_params = [] @@ -90,15 +87,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True): # Check if there are any extracted peaks - if not, don't add # Note that we only check peaks, but gauss should be the same if not np.all(np.isnan(peaks)): - peak_params.append(avg_func(peaks, 0)) - gauss_params.append(avg_func(gauss, 0)) + peak_params.append(avg_funcs[avg_method](peaks, 0)) + gauss_params.append(avg_funcs[avg_method](gauss, 0)) peak_params = np.array(peak_params) gauss_params = np.array(gauss_params) # Goodness of fit measures: extract & average - r2 = avg_func(fg.get_params('r_squared')) - error = avg_func(fg.get_params('error')) + r2 = avg_funcs[avg_method](fg.get_params('r_squared')) + error = avg_funcs[avg_method](fg.get_params('error')) # Collect all results together, to be added to FOOOF object results = FOOOFResults(ap_params, peak_params, r2, error, gauss_params) @@ -116,6 +113,41 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True): return fm +def average_reconstructions(fg, avg_method='mean'): + """Average across model reconstructions for a group of power spectra. + + Parameters + ---------- + fg : FOOOFGroup + Object with model fit results to average across. + avg : {'mean', 'median'} + Averaging function to use. + + Returns + ------- + freqs : 1d array + Frequency values for the average model reconstruction. + avg_model : 1d array + Power values for the average model reconstruction. + Note that power values are in log10 space. + """ + + if not fg.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian} + if avg_method not in avg_funcs.keys(): + raise ValueError("Requested average method not understood.") + + models = np.zeros(shape=fg.power_spectra.shape) + for ind in range(len(fg)): + models[ind, :] = fg.get_fooof(ind, regenerate=True).fooofed_spectrum_ + + avg_model = avg_funcs[avg_method](models, 0) + + return fg.freqs, avg_model + + def combine_fooofs(fooofs): """Combine a group of FOOOF and/or FOOOFGroup objects into a single FOOOFGroup object. diff --git a/fooof/plts/__init__.py b/fooof/plts/__init__.py index 95e05f40..981ba12b 100644 --- a/fooof/plts/__init__.py +++ b/fooof/plts/__init__.py @@ -1,4 +1,3 @@ """Plots sub-module for FOOOF.""" -from .spectra import plot_spectra -from .spectra import plot_spectra as plot_spectrum +from .spectra import plot_spectrum, plot_spectra diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index c68acc69..8141b7cd 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -77,6 +77,10 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, style_spectrum_plot(ax, log_freqs, log_powers) +# Alias `plot_spectrum` to `plot_spectra` for backwards compatibility +plot_spectrum = plot_spectra + + @savefig @check_dependency(plt, 'matplotlib') def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', diff --git a/fooof/tests/objs/test_utils.py b/fooof/tests/objs/test_utils.py index b496c1bc..28a4c87e 100644 --- a/fooof/tests/objs/test_utils.py +++ b/fooof/tests/objs/test_utils.py @@ -45,6 +45,13 @@ def test_average_fg(tfg, tbands): with raises(NoModelError): average_fg(ntfg, tbands) +def test_average_reconstructions(tfg): + + freqs, avg_model = average_reconstructions(tfg) + assert isinstance(freqs, np.ndarray) + assert isinstance(avg_model, np.ndarray) + assert freqs.shape == avg_model.shape + def test_combine_fooofs(tfm, tfg): tfm2 = tfm.copy()