Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Add average_reconstructions function #289

Merged
merged 3 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Functions to manipulate, examine and analyze FOOOF objects, and related utilitie

compare_info
average_fg
average_reconstructions
combine_fooofs

.. currentmodule:: fooof
Expand Down
54 changes: 43 additions & 11 deletions fooof/objs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
If there are no model fit results available to average across.
"""

if avg_method not in ['mean', 'median']:
raise ValueError("Requested average method not understood.")
if not fg.has_model:
raise NoModelError("No model fit results are available, can not proceed.")

if avg_method == 'mean':
avg_func = np.nanmean
elif avg_method == 'median':
avg_func = np.nanmedian
avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
if avg_method not in avg_funcs.keys():
raise ValueError("Requested average method not understood.")

# Aperiodic parameters: extract & average
ap_params = avg_func(fg.get_params('aperiodic_params'), 0)
ap_params = avg_funcs[avg_method](fg.get_params('aperiodic_params'), 0)

# Periodic parameters: extract & average
peak_params = []
Expand All @@ -90,15 +87,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
# Check if there are any extracted peaks - if not, don't add
# Note that we only check peaks, but gauss should be the same
if not np.all(np.isnan(peaks)):
peak_params.append(avg_func(peaks, 0))
gauss_params.append(avg_func(gauss, 0))
peak_params.append(avg_funcs[avg_method](peaks, 0))
gauss_params.append(avg_funcs[avg_method](gauss, 0))

peak_params = np.array(peak_params)
gauss_params = np.array(gauss_params)

# Goodness of fit measures: extract & average
r2 = avg_func(fg.get_params('r_squared'))
error = avg_func(fg.get_params('error'))
r2 = avg_funcs[avg_method](fg.get_params('r_squared'))
error = avg_funcs[avg_method](fg.get_params('error'))

# Collect all results together, to be added to FOOOF object
results = FOOOFResults(ap_params, peak_params, r2, error, gauss_params)
Expand All @@ -116,6 +113,41 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
return fm


def average_reconstructions(fg, avg_method='mean'):
"""Average across model reconstructions for a group of power spectra.

Parameters
----------
fg : FOOOFGroup
Object with model fit results to average across.
avg : {'mean', 'median'}
Averaging function to use.

Returns
-------
freqs : 1d array
Frequency values for the average model reconstruction.
avg_model : 1d array
Power values for the average model reconstruction.
Note that power values are in log10 space.
"""

if not fg.has_model:
raise NoModelError("No model fit results are available, can not proceed.")

avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
if avg_method not in avg_funcs.keys():
raise ValueError("Requested average method not understood.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above checks are duplicated in average_fg and could be moved to a shared check func.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeh, I agree! I didn't go for that here, because there are related changes in #283 that add a general function for averaging data, but I didn't bother trying to backport that, and so the plan is to consolidate this into the 2.0 release and clean up the code here when merging!


models = np.zeros(shape=fg.power_spectra.shape)
for ind in range(len(fg)):
models[ind, :] = fg.get_fooof(ind, regenerate=True).fooofed_spectrum_

avg_model = avg_funcs[avg_method](models, 0)

return fg.freqs, avg_model


def combine_fooofs(fooofs):
"""Combine a group of FOOOF and/or FOOOFGroup objects into a single FOOOFGroup object.

Expand Down
3 changes: 1 addition & 2 deletions fooof/plts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""Plots sub-module for FOOOF."""

from .spectra import plot_spectra
from .spectra import plot_spectra as plot_spectrum
from .spectra import plot_spectrum, plot_spectra
4 changes: 4 additions & 0 deletions fooof/plts/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False,
style_spectrum_plot(ax, log_freqs, log_powers)


# Alias `plot_spectrum` to `plot_spectra` for backwards compatibility
plot_spectrum = plot_spectra


@savefig
@check_dependency(plt, 'matplotlib')
def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',
Expand Down
7 changes: 7 additions & 0 deletions fooof/tests/objs/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def test_average_fg(tfg, tbands):
with raises(NoModelError):
average_fg(ntfg, tbands)

def test_average_reconstructions(tfg):

freqs, avg_model = average_reconstructions(tfg)
assert isinstance(freqs, np.ndarray)
assert isinstance(avg_model, np.ndarray)
assert freqs.shape == avg_model.shape

def test_combine_fooofs(tfm, tfg):

tfm2 = tfm.copy()
Expand Down