-
Notifications
You must be signed in to change notification settings - Fork 0
/
sweep.py
84 lines (73 loc) · 2.52 KB
/
sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import wandb
def parse_args(argv=None):
    """Parse the command-line arguments of the sweep launcher.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches
            the behavior of running the script from the command line.

    Returns:
        argparse.Namespace with attributes ``data`` (str, required),
        ``mme`` (bool), ``sla`` (bool) and ``suffix`` (str, default ``''``).
    """
    parser = argparse.ArgumentParser(
        prog='python sweep.py',
        description='Initialize a wandb hyperparameter sweep for PopGenAdapt '
                    'for a given dataset and SSDA method.',
    )
    parser.add_argument('--data', type=str, required=True,
                        help='Path to the dataset JSON file.')
    parser.add_argument('--mme', action='store_true',
                        help='Whether to use minimax entropy.')
    parser.add_argument('--sla', action='store_true',
                        help='Whether to use source label adaptation.')
    # The suffix lets sweeps over different datasets land in distinct projects.
    parser.add_argument('--suffix', type=str, default='',
                        help='Suffix to append to the project name.')
    return parser.parse_args(argv)
def build_project_name(args):
    """Return the wandb project name for the selected SSDA configuration.

    The name has the form ``PopGenAdapt-{base|mme}[-sla][-<suffix>]``.

    Args:
        args: Namespace with ``mme`` (bool), ``sla`` (bool) and ``suffix``
            (str) attributes, as produced by ``parse_args``.

    Returns:
        The project name string.
    """
    name = f"PopGenAdapt-{'mme' if args.mme else 'base'}"
    if args.sla:
        name += '-sla'
    if args.suffix:
        name += f'-{args.suffix}'
    return name


def build_sweep_config(args, project_name):
    """Build the wandb sweep configuration dictionary.

    The sweep runs ``main.py`` with random search, maximizing ``best_val_auc``.
    The learning rate is always swept. ``mme_lambda`` is swept only when
    ``--mme`` is set, and the ``sla_*`` hyperparameters only when ``--sla`` is
    set.

    NOTE(review): the previous version added ``mme_lambda`` only under
    ``--sla``. It is assumed, from its name and the ``--mme`` flag, to matter
    only for MME runs. Confirm against main.py.

    Args:
        args: Namespace with ``data`` (str), ``mme`` (bool) and ``sla`` (bool)
            attributes, as produced by ``parse_args``.
        project_name: wandb project that each run logs to (forwarded to
            main.py via ``--project``).

    Returns:
        A dict suitable for ``wandb.sweep``.
    """
    # Fixed arguments for every run. '${env}', '${program}' and '${args}' are
    # wandb command macros; '${args}' expands to the sampled hyperparameters.
    command = [
        '${env}',
        'python3',
        '${program}',
        '--verbose',
        '--project', project_name,
        '--seed', '42',
        '--data', args.data,
    ]
    if args.mme:
        command.append('--mme')
    if args.sla:
        command.append('--sla')
    command.append('${args}')

    parameters = {
        'lr': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-2,
        },
    }
    if args.mme:
        parameters['mme_lambda'] = {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 1.0,
        }
    if args.sla:
        parameters.update({
            'sla_warmup': {
                'values': [100, 500, 1000, 2000, 5000],
            },
            'sla_temperature': {
                'distribution': 'uniform',
                'min': 0.0,
                'max': 1.0,
            },
            'sla_alpha': {
                'distribution': 'uniform',
                'min': 0.0,
                'max': 1.0,
            },
            'sla_update_interval': {
                'values': [5, 10, 100, 500, 1000, 2000, 5000],
            },
        })

    return {
        'program': 'main.py',
        'method': 'random',
        'metric': {
            'name': 'best_val_auc',
            'goal': 'maximize',
        },
        'parameters': parameters,
        'command': command,
    }


if __name__ == '__main__':
    args = parse_args()
    project_name = build_project_name(args)
    sweep_config = build_sweep_config(args, project_name)
    sweep_id = wandb.sweep(sweep_config, project=project_name)
    # Pass the project explicitly so the agent resolves the sweep in the same
    # project it was created in, not the default/environment project.
    wandb.agent(sweep_id, project=project_name)