Skip to content

Commit

Permalink
use plot_text in report generation
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Jul 15, 2023
1 parent b947748 commit 7a08faf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 40 deletions.
48 changes: 9 additions & 39 deletions specparam/core/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
gen_group_results_str, gen_time_results_str,
gen_event_results_str)
from specparam.data.utils import get_periodic_labels
from specparam.plts.templates import plot_text
from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness,
plot_group_peak_frequencies)

Expand All @@ -17,9 +18,6 @@

## Settings & Globals
REPORT_FIGSIZE = (16, 20)
REPORT_FONT = {'family': 'monospace',
'weight': 'normal',
'size': 16}
SAVE_FORMAT = 'pdf'

###################################################################################################
Expand Down Expand Up @@ -52,23 +50,15 @@ def save_model_report(model, file_name, file_path=None, add_settings=True, **plo
grid = gridspec.GridSpec(n_rows, 1, hspace=0.25, height_ratios=height_ratios)

# First - text results
ax0 = plt.subplot(grid[0])
results_str = gen_model_results_str(model)
ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center')
ax0.set_frame_on(False)
ax0.set(xticks=[], yticks=[])
plot_text(gen_model_results_str(model), 0.5, 0.7, ax=plt.subplot(grid[0]))

# Second - data plot
ax1 = plt.subplot(grid[1])
model.plot(ax=ax1, **plot_kwargs)

# Third - model settings
if add_settings:
ax2 = plt.subplot(grid[2])
settings_str = gen_settings_str(model, False)
ax2.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center')
ax2.set_frame_on(False)
ax2.set(xticks=[], yticks=[])
plot_text(gen_settings_str(model, False), 0.5, 0.1, ax=plt.subplot(grid[2]))

# Save out the report
plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT)))
Expand Down Expand Up @@ -100,11 +90,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True):
grid = gridspec.GridSpec(n_rows, 2, wspace=0.4, hspace=0.25, height_ratios=height_ratios)

# First / top: text results
ax0 = plt.subplot(grid[0, :])
results_str = gen_group_results_str(group)
ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center')
ax0.set_frame_on(False)
ax0.set(xticks=[], yticks=[])
plot_text(gen_group_results_str(group), 0.5, 0.7, ax=plt.subplot(grid[0, :]))

# Second - data plots

Expand All @@ -122,11 +108,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True):

# Third - Model settings
if add_settings:
ax4 = plt.subplot(grid[3, :])
settings_str = gen_settings_str(group, False)
ax4.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center')
ax4.set_frame_on(False)
ax4.set(xticks=[], yticks=[])
plot_text(gen_settings_str(group, False), 0.5, 0.1, ax=plt.subplot(grid[3, :]))

# Save out the report
plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT)))
Expand Down Expand Up @@ -161,20 +143,14 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True):
figsize=REPORT_FIGSIZE)

# First / top: text results
results_str = gen_time_results_str(time_model)
axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center')
axes[0].set_frame_on(False)
axes[0].set(xticks=[], yticks=[])
plot_text(gen_time_results_str(time_model), 0.5, 0.7, ax=axes[0])

# Second - data plots
time_model.plot(axes=axes[1:2+n_bands+1])

# Third - Model settings
if add_settings:
settings_str = gen_settings_str(time_model, False)
axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center')
axes[-1].set_frame_on(False)
axes[-1].set(xticks=[], yticks=[])
plot_text(gen_settings_str(time_model, False), 0.5, 0.1, ax=axes[-1])

# Save out the report
plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT)))
Expand Down Expand Up @@ -211,20 +187,14 @@ def save_event_report(event_model, file_name, file_path=None, add_settings=True)
figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 6))

# First / top: text results
results_str = gen_event_results_str(event_model)
axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center')
axes[0].set_frame_on(False)
axes[0].set(xticks=[], yticks=[])
plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0])

# Second - data plots
event_model.plot(axes=axes[1:-1])

# Third - Model settings
if add_settings:
settings_str = gen_settings_str(event_model, False)
axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center')
axes[-1].set_frame_on(False)
axes[-1].set(xticks=[], yticks=[])
plot_text(gen_settings_str(event_model, False), 0.5, 0.1, ax=axes[-1])

# Save out the report
plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT)))
Expand Down
2 changes: 1 addition & 1 deletion specparam/tests/tutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_tfe():
xs, ys = sim_spectrogram(n_spectra, *default_group_params())
ys = [ys, ys]

bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)})
bands = Bands({'alpha' : (7, 14)})
tfe = SpectralTimeEventModel(verbose=False)
tfe.fit(xs, ys, peak_org=bands)

Expand Down

0 comments on commit 7a08faf

Please sign in to comment.