Skip to content

Commit

Permalink
add tests & corresponding updates / fixes for event loading and assoc…
Browse files Browse the repository at this point in the history
…aited
  • Loading branch information
TomDonoghue committed Jul 17, 2023
1 parent 409db0d commit 041067e
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 2 deletions.
6 changes: 5 additions & 1 deletion specparam/objs/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,17 @@ def load(self, file_name, file_path=None, peak_org=None):
"""

files = get_files(file_path, select=file_name)
spectrograms = []
for file in files:
super().load(file, file_path, peak_org=False)
if self.group_results:
self.event_group_results.append(self.group_results)
if np.all(self.power_spectra):
spectrograms.append(self.spectrogram)
self.spectrograms = np.array(spectrograms) if spectrograms else None

self._reset_group_results()
if peak_org is not False:
if peak_org is not False and self.event_group_results:
self.convert_results(peak_org)


Expand Down
2 changes: 1 addition & 1 deletion specparam/objs/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def load(self, file_name, file_path=None, peak_org=None):
# Clear results so as not to have possible prior results interfere
self._reset_time_results()
super().load(file_name, file_path=file_path)
if peak_org is not False:
if peak_org is not False and self.group_results:
self.convert_results(peak_org)


Expand Down
16 changes: 16 additions & 0 deletions specparam/tests/core/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,22 @@ def test_save_time(tft):
assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json'))
assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json'))

def test_save_event(tfe):
"""Check saving fe data."""

res_file_name = 'test_event_res'
set_file_name = 'test_event_set'
dat_file_name = 'test_event_dat'

save_event(tfe, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True)
save_event(tfe, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True)
save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True)

assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json'))
for ind in range(len(tfe)):
assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '_' + str(ind) + '.json'))
assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '_' + str(ind) + '.json'))

def test_load_json_str():
"""Test loading JSON file, with str file specifier.
Loads files from test_save_model_str.
Expand Down
21 changes: 21 additions & 0 deletions specparam/tests/objs/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,27 @@ def test_event_report(skip_if_no_mpl):

assert tfe

def test_event_load(tbands):

file_name_res = 'test_event_res'
file_name_set = 'test_event_set'
file_name_dat = 'test_event_dat'

# Test loading results
tfe = SpectralTimeEventModel(verbose=False)
tfe.load(file_name_res, TEST_DATA_PATH, peak_org=tbands)
assert tfe.event_time_results

# Test loading settings
tfe = SpectralTimeEventModel(verbose=False)
tfe.load(file_name_set, TEST_DATA_PATH)
assert tfe.get_settings()

# Test loading data
tfe = SpectralTimeEventModel(verbose=False)
tfe.load(file_name_dat, TEST_DATA_PATH)
assert np.all(tfe.spectrograms)

def test_event_get_model(tfe):

# Check without regenerating
Expand Down
10 changes: 10 additions & 0 deletions specparam/tests/objs/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def test_time_load(tbands):
tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands)
assert tft.time_results

# Test loading settings
tft = SpectralTimeModel(verbose=False)
tft.load(file_name_set, TEST_DATA_PATH)
assert tft.get_settings()

# Test loading data
tft = SpectralTimeModel(verbose=False)
tft.load(file_name_dat, TEST_DATA_PATH)
assert np.all(tft.power_spectra)

def test_time_drop():

n_windows = 3
Expand Down

0 comments on commit 041067e

Please sign in to comment.