Skip to content

Commit

Permalink
Merge pull request #201 from fooof-tools/optimize
Browse files Browse the repository at this point in the history
[ENH] Optimization
  • Loading branch information
TomDonoghue authored Mar 28, 2021
2 parents f81ee12 + c2c8643 commit 7c75d64
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 20 deletions.
14 changes: 9 additions & 5 deletions fooof/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ def group_three(vec):
Parameters
----------
vec : 1d array
Array of items to group by 3. Length of array must be divisible by three.
vec : list or 1d array
List or array of items to group by 3. Length of array must be divisible by three.
Returns
-------
list of list
List of lists, each with three items.
array or list of list
Array or list of lists, each with three items. Output type will match input type.
Raises
------
Expand All @@ -30,7 +30,11 @@ def group_three(vec):
if len(vec) % 3 != 0:
raise ValueError("Wrong size array to group by three.")

return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
# Reshape, if an array, as it's faster, otherwise asssume lise
if isinstance(vec, np.ndarray):
return np.reshape(vec, (-1, 3))
else:
return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]


def nearest_ind(array, value):
Expand Down
14 changes: 6 additions & 8 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,18 +1005,16 @@ def _create_peak_params(self, gaus_params):
with `freqs`, `fooofed_spectrum_` and `_ap_fit` all required to be available.
"""

peak_params = np.empty([0, 3])
peak_params = np.empty((len(gaus_params), 3))

for ii, peak in enumerate(gaus_params):

# Gets the index of the power_spectrum at the frequency closest to the CF of the peak
ind = min(range(len(self.freqs)), key=lambda ii: abs(self.freqs[ii] - peak[0]))
ind = np.argmin(np.abs(self.freqs - peak[0]))

# Collect peak parameter data
peak_params = np.vstack((peak_params,
[peak[0],
self.fooofed_spectrum_[ind] - self._ap_fit[ind],
peak[2] * 2]))
peak_params[ii] = [peak[0], self.fooofed_spectrum_[ind] - self._ap_fit[ind],
peak[2] * 2]

return peak_params

Expand All @@ -1035,8 +1033,8 @@ def _drop_peak_cf(self, guess):
Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].
"""

cf_params = [item[0] for item in guess]
bw_params = [item[2] * self._bw_std_edge for item in guess]
cf_params = guess[:, 0]
bw_params = guess[:, 2] * self._bw_std_edge

# Check if peaks within drop threshold from the edge of the frequency range
keep_peak = \
Expand Down
13 changes: 9 additions & 4 deletions fooof/objs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,14 @@ def fit_fooof_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
>>> fgs = fit_fooof_3d(fg, freqs, power_spectra, freq_range=[3, 30]) # doctest:+SKIP
"""

fgs = []
for cond_spectra in power_spectra:
fg.fit(freqs, cond_spectra, freq_range, n_jobs)
fgs.append(fg.copy())
# Reshape 3d data to 2d and fit, in order to fit with a single group model object
shape = np.shape(power_spectra)
powers_2d = np.reshape(power_spectra, (shape[0] * shape[1], shape[2]))

fg.fit(freqs, powers_2d, freq_range, n_jobs)

# Reorganize 2d results into a list of model group objects, to reflect original shape
fgs = [fg.get_group(range(dim_a * shape[1], (dim_a + 1) * shape[1])) \
for dim_a in range(shape[0])]

return fgs
10 changes: 7 additions & 3 deletions fooof/tests/objs/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,17 @@ def test_combine_errors(tfm, tfg):

def test_fit_fooof_3d(tfg):

n_spectra = 2
n_groups = 2
n_spectra = 3
xs, ys = gen_group_power_spectra(n_spectra, *default_group_params())
ys = np.stack([ys, ys], axis=0)
ys = np.stack([ys] * n_groups, axis=0)
spectra_shape = np.shape(ys)

tfg = FOOOFGroup()
fgs = fit_fooof_3d(tfg, xs, ys)

assert len(fgs) == 2
assert len(fgs) == n_groups == spectra_shape[0]
for fg in fgs:
assert fg
assert len(fg) == n_spectra
assert fg.power_spectra.shape == spectra_shape[1:]

0 comments on commit 7c75d64

Please sign in to comment.