Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] - Technical Audit: Time Frequency edits #199

Merged
merged 10 commits into from
Jul 6, 2020
2 changes: 1 addition & 1 deletion neurodsp/spectral/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def compute_spectrum_wavelet(sig, fs, freqs, avg_type='mean', **kwargs):
freqs = create_freqs(*freqs)

mwt = compute_wavelet_transform(sig, fs, freqs, **kwargs)
spectrum = get_avg_func(avg_type)(mwt, axis=0)
spectrum = get_avg_func(avg_type)(mwt, axis=1)

return freqs, spectrum

Expand Down
8 changes: 3 additions & 5 deletions neurodsp/timefrequency/hilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def robust_hilbert(sig, increase_n=False):
Returns
-------
sig_hilb : 1d array
The Hilbert transform of the input signal.
The analytic signal, of which the imaginary part is the Hilbert transform of the input signal.

Examples
--------
Expand Down Expand Up @@ -193,12 +193,10 @@ def freq_by_time(sig, fs, f_range=None, hilbert_increase_n=False,
... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}})
>>> instant_freq = freq_by_time(sig, fs=500, f_range=(8, 12))
"""

pha = phase_by_time(sig, fs, f_range, hilbert_increase_n,
remove_edges, **filter_kwargs)
pha = np.unwrap(phase_by_time(sig, fs, f_range, hilbert_increase_n,
remove_edges, **filter_kwargs))

phadiff = np.diff(pha)
phadiff[phadiff < 0] = phadiff[phadiff < 0] + 2 * np.pi

i_f = fs * phadiff / (2 * np.pi)
i_f = np.insert(i_f, 0, np.nan)
Expand Down
6 changes: 3 additions & 3 deletions neurodsp/timefrequency/wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
###################################################################################################

@multidim()
def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5):
def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='sss'):
elybrand marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the time-frequency representation of a signal using morlet wavelets.

Parameters
Expand Down Expand Up @@ -48,9 +48,9 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5):
freqs = create_freqs(*freqs)
n_cycles = check_n_cycles(n_cycles, len(freqs))

mwt = np.zeros([len(sig), len(freqs)], dtype=complex)
mwt = np.zeros([len(freqs), len(sig)], dtype=complex)
for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)):
mwt[:, ind] = convolve_wavelet(sig, fs, freq, n_cycle, scaling)
mwt[ind, :] = convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm)

return mwt

Expand Down