-
Notifications
You must be signed in to change notification settings - Fork 10
/
plot_predictions_scatter.py
40 lines (32 loc) · 1.06 KB
/
plot_predictions_scatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# %% imports
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Default selections used when -d/--dataset or -b/--benchmark is omitted.
DATASETS = ['chbp', 'lemon', 'tuab', 'camcan']
BENCHMARKS = ['dummy', 'filterbank-riemann', 'filterbank-source',
              'handcrafted', 'shallow', 'deep']


def parse_args(argv=None):
    """Parse the command-line options selecting which predictions to plot.

    Parameters
    ----------
    argv : list of str | None
        Arguments to parse. ``None`` lets argparse read ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Has ``dataset`` and ``benchmark`` attributes. Each is a list of
        names, or ``None`` when the corresponding option was not given.
    """
    parser = argparse.ArgumentParser(
        description='Scatter-plot true vs. predicted values of benchmarks.')
    parser.add_argument(
        '-d', '--dataset',
        default=None,
        nargs='+',
        help='the dataset(s) whose predictions should be plotted '
             '(default: all)')
    parser.add_argument(
        '-b', '--benchmark',
        default=None,
        nargs='+',
        help='the benchmark(s) whose predictions should be plotted '
             '(default: all)')
    return parser.parse_args(argv)


def build_tasks(datasets=None, benchmarks=None):
    """Return every (dataset, benchmark) pair to plot, ordered dataset-major.

    ``None`` for either argument falls back to ``DATASETS`` /
    ``BENCHMARKS`` respectively.
    """
    if datasets is None:
        datasets = DATASETS
    if benchmarks is None:
        benchmarks = BENCHMARKS
    return [(ds, bs) for ds in datasets for bs in benchmarks]


def results_path(dataset, benchmark):
    """Return the CSV path holding the predictions of one benchmark run."""
    return f"./results/benchmark-{benchmark}_dataset-{dataset}_ys.csv"


def plot_predictions(dataset, benchmark):
    """Show a scatter plot of ``y_true`` vs. ``y_pred`` for one run.

    Parameters
    ----------
    dataset : str
        Dataset name used in the results file name.
    benchmark : str
        Benchmark name used in the results file name.

    Returns
    -------
    bool
        ``True`` if a plot was shown. ``False`` if the results file does
        not exist, in which case a message is printed and nothing is drawn.
    """
    fname = results_path(dataset, benchmark)
    try:
        ys = pd.read_csv(fname)
    except FileNotFoundError:
        # Not every (dataset, benchmark) combination may have been run.
        # Report it and let the caller continue instead of aborting.
        print(f"Skipping '{benchmark}' on '{dataset}': '{fname}' not found")
        return False
    # An explicit figure per task keeps plots from piling onto the same
    # implicit axes, e.g. in interactive sessions where show() doesn't close.
    _, ax = plt.subplots()
    sns.scatterplot(x="y_true", y="y_pred", data=ys, ax=ax)
    ax.set_title(f"{benchmark} on {dataset}")
    plt.show()
    return True


def main(argv=None):
    """Plot predictions for every requested (dataset, benchmark) pair."""
    parsed = parse_args(argv)
    for dataset, benchmark in build_tasks(parsed.dataset, parsed.benchmark):
        print(f"Plotting for '{benchmark}' on '{dataset}' data")
        plot_predictions(dataset, benchmark)


if __name__ == "__main__":
    main()