-
Notifications
You must be signed in to change notification settings - Fork 216
/
train.py
executable file
·436 lines (356 loc) · 18.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
#!/usr/bin/env python3
from argparse import ArgumentParser
from datetime import timedelta
from importlib import import_module
import logging.config
import os
from signal import SIGINT, SIGTERM
import sys
import time
import json
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim
import common
import lbtoolbox as lb
import loss
from nets import NET_CHOICES
from heads import HEAD_CHOICES
# Command-line interface of the training script. The parser is built at module
# level and consumed by `main()` via `parser.parse_args()`.
parser = ArgumentParser(description='Train a ReID network.')

# Required.

parser.add_argument(
    '--experiment_root', required=True, type=common.writeable_directory,
    help='Location used to store checkpoints and dumped data.')

# The next two are deliberately not marked `required=True`: `main()` checks
# for them itself, so that they can be omitted when `--resume` is used.

parser.add_argument(
    '--train_set',
    help='Path to the train_set csv file.')

parser.add_argument(
    '--image_root', type=common.readable_directory,
    help='Path that will be pre-pended to the filenames in the train_set csv.')

# Optional with sane defaults.

parser.add_argument(
    '--resume', action='store_true', default=False,
    help='When this flag is provided, all other arguments apart from the '
         'experiment_root are ignored and a previously saved set of arguments '
         'is loaded.')

parser.add_argument(
    '--model_name', default='resnet_v1_50', choices=NET_CHOICES,
    help='Name of the model to use.')

parser.add_argument(
    '--head_name', default='fc1024', choices=HEAD_CHOICES,
    help='Name of the head to use.')

parser.add_argument(
    '--embedding_dim', default=128, type=common.positive_int,
    help='Dimensionality of the embedding space.')

parser.add_argument(
    '--initial_checkpoint', default=None,
    help='Path to the checkpoint file of the pretrained network.')

# TODO move these defaults to the .sh script?
parser.add_argument(
    '--batch_p', default=32, type=common.positive_int,
    help='The number P used in the PK-batches')

parser.add_argument(
    '--batch_k', default=4, type=common.positive_int,
    help='The number K used in the PK-batches')

parser.add_argument(
    '--net_input_height', default=256, type=common.positive_int,
    help='Height of the input directly fed into the network.')

parser.add_argument(
    '--net_input_width', default=128, type=common.positive_int,
    help='Width of the input directly fed into the network.')

parser.add_argument(
    '--pre_crop_height', default=288, type=common.positive_int,
    help='Height used to resize a loaded image. This is ignored when no crop '
         'augmentation is applied.')

parser.add_argument(
    '--pre_crop_width', default=144, type=common.positive_int,
    help='Width used to resize a loaded image. This is ignored when no crop '
         'augmentation is applied.')
# TODO end

parser.add_argument(
    '--loading_threads', default=8, type=common.positive_int,
    help='Number of threads used for parallel loading.')

parser.add_argument(
    '--margin', default='soft', type=common.float_or_string,
    help='What margin to use: a float value for hard-margin, "soft" for '
         'soft-margin, or no margin if "none".')

parser.add_argument(
    '--metric', default='euclidean', choices=loss.cdist.supported_metrics,
    help='Which metric to use for the distance between embeddings.')

parser.add_argument(
    '--loss', default='batch_hard', choices=loss.LOSS_CHOICES.keys(),
    help='Name of the loss function to use for training.')

parser.add_argument(
    '--learning_rate', default=3e-4, type=common.positive_float,
    help='The initial value of the learning-rate, before the decay kicks in.')

parser.add_argument(
    '--train_iterations', default=25000, type=common.positive_int,
    help='Number of training iterations.')

parser.add_argument(
    '--decay_start_iteration', default=15000, type=int,
    help='At which iteration the learning-rate decay should kick-in. '
         'Set to -1 to disable decay completely.')

parser.add_argument(
    '--checkpoint_frequency', default=1000, type=common.nonnegative_int,
    help='After how many iterations a checkpoint is stored. Set this to 0 to '
         'disable intermediate storing. This will result in only one final '
         'checkpoint.')

parser.add_argument(
    '--flip_augment', action='store_true', default=False,
    help='When this flag is provided, flip augmentation is performed.')

parser.add_argument(
    '--crop_augment', action='store_true', default=False,
    help='When this flag is provided, crop augmentation is performed: images '
         'are resized to `pre_crop_height` x `pre_crop_width` and a random '
         'crop of `net_input_height` x `net_input_width` is fed to the '
         'network.')

parser.add_argument(
    '--detailed_logs', action='store_true', default=False,
    help='Store very detailed logs of the training in addition to TensorBoard'
         ' summaries. These are mem-mapped numpy files containing the'
         ' embeddings, losses and FIDs seen in each batch during training.'
         ' Everything can be re-constructed and analyzed that way.')
def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
    """Pick `batch_k` FIDs belonging to the given PID.

    Args:
        pid: The PID to sample for.
        all_fids: All FIDs of the dataset, aligned with `all_pids`.
        all_pids: The PID of every entry in `all_fids`.
        batch_k: How many FIDs to return.

    Returns:
        A pair `(selected_fids, pids)`: the `batch_k` gathered FIDs and a
        tensor of length `batch_k` filled with `pid`.
    """
    # Keep only the FIDs whose PID matches the requested one.
    belongs_to_pid = tf.equal(all_pids, pid)
    candidates = tf.boolean_mask(all_fids, belongs_to_pid)
    num_candidates = tf.shape(candidates)[0]

    # Build an index list that cycles through the candidates often enough to
    # hold at least `batch_k` entries. Its length is always a whole multiple
    # of the candidate count, so every candidate appears equally often and is
    # therefore equally likely to be drawn, even when fewer than `batch_k`
    # candidates exist and some have to be repeated.
    repeats = tf.cast(
        tf.ceil(batch_k / tf.cast(num_candidates, tf.float32)), tf.int32)
    cyclic_indices = tf.mod(
        tf.range(repeats * num_candidates), num_candidates)

    # Drawing the sample = shuffling the indices and keeping the first K.
    chosen = tf.random_shuffle(cyclic_indices)[:batch_k]
    return tf.gather(candidates, chosen), tf.fill([batch_k], pid)
def main():
    """Run a full training session as configured on the command line.

    The steps are, in order:
      1. Parse the arguments; with `--resume`, overwrite them with the values
         stored in `<experiment_root>/args.json`, otherwise store them there.
      2. Build the `tf.data` input pipeline producing PK-batches.
      3. Build the network, the embedding head, the loss and the optimizer.
      4. Restore (when resuming) or initialize the variables and run the
         training loop, writing summaries, optional detailed logs and
         checkpoints into `experiment_root`.

    Raises:
        IOError: When resuming and either `args.json` or a checkpoint cannot
            be found in the experiment directory.
        SystemExit: When a fresh run targets a non-empty directory, or when
            `train_set` / `image_root` are missing.
    """
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    # Only existing keys are re-assigned, so the dict does not
                    # change size while we iterate over it.
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                # `sys.exit` rather than the `exit` builtin, which is only
                # injected by the `site` module and is meant for interactive use.
                sys.exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors. When crop augmentation is on,
    # images are loaded at the larger pre-crop size and cropped down below.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image(
            fid, pid, image_root=args.image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (3,)), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
        dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k-1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32, shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32, shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len), shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            if last_checkpoint is None:
                # Fail with a clear message instead of handing `None` to
                # `Saver.restore`.
                raise IOError('No checkpoint found in {} to resume from.'.format(
                    args.experiment_root))
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.
            checkpoint_saver.save(sess, os.path.join(
                args.experiment_root, 'checkpoint'), global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # `step` is normally bound inside the loop. Pre-set it so that the
        # final checkpoint below still works when the loop body never runs
        # (e.g. when resuming a run that already reached `train_iterations`).
        step = start_step

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress. The ETA simply
                # extrapolates the duration of the latest iteration.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                         'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                             step,
                             float(np.min(b_loss)),
                             float(np.mean(b_loss)),
                             float(np.max(b_loss)),
                             args.batch_k-1, float(b_prec_at_k),
                             timedelta(seconds=int(seconds_todo)),
                             elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0 and
                        step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess, os.path.join(
                        args.experiment_root, 'checkpoint'), global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess, os.path.join(
            args.experiment_root, 'checkpoint'), global_step=step)

        # Flush pending summary events to disk and release the event file.
        summary_writer.close()


if __name__ == '__main__':
    main()