Sampling, batch_all, non-zero, optimizer-flag #33

Open · wants to merge 4 commits into base: master
111 changes: 92 additions & 19 deletions loss.py
@@ -66,7 +66,19 @@ def get_at_indices(tensor, indices):
return tf.gather_nd(tensor, tf.stack((counter, indices), -1))


def batch_hard(dists, pids, margin, batch_precision_at_k=None):
def apply_margin(x, margin):
if isinstance(margin, numbers.Real):
return tf.maximum(x + margin, 0.0)
elif margin == 'soft':
return tf.nn.softplus(x)
elif margin.lower() == 'none':
return x
else:
raise NotImplementedError(
'The margin {} is not implemented in apply_margin'.format(margin))
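
For reference, writing d for the difference (positive distance minus negative distance) that the callers below pass in, the three margin modes of apply_margin amount to the following (a sketch of the math implied by the code, not part of the diff):

$$\ell(d) = [\,d + m\,]_+ \ \text{(real margin } m\text{)}, \qquad \ell(d) = \operatorname{softplus}(d) = \ln(1 + e^{d}) \ \text{('soft')}, \qquad \ell(d) = d \ \text{('none')}.$$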


def _generic_batchloss(dists, pids, margin, batch_precision_at_k=None, variant='hard'):
"""Computes the batch-hard loss from arxiv.org/abs/1703.07737.

Args:
@@ -87,25 +99,83 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
positive_mask = tf.logical_xor(same_identity_mask,
tf.eye(tf.shape(pids)[0], dtype=tf.bool))

furthest_positive = tf.reduce_max(dists*tf.cast(positive_mask, tf.float32), axis=1)
closest_negative = tf.map_fn(lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])),
(dists, negative_mask), tf.float32)
# Another way of achieving the same, though more hacky:
# closest_negative = tf.reduce_min(dists + 1e5*tf.cast(same_identity_mask, tf.float32), axis=1)

diff = furthest_positive - closest_negative
if isinstance(margin, numbers.Real):
diff = tf.maximum(diff + margin, 0.0)
elif margin == 'soft':
diff = tf.nn.softplus(diff)
elif margin.lower() == 'none':
pass
else:
raise NotImplementedError(
'The margin {} is not implemented in batch_hard'.format(margin))
if variant == 'sample':
# -inf gives that index a probability of zero.
neg_infs = -tf.constant(float('inf'))*tf.ones_like(dists)
# higher logits are more likely to be sampled.
pos_logits = tf.where(positive_mask, dists, neg_infs)
pos_indices = tf.multinomial(pos_logits, num_samples=1)[:,0]
positive = get_at_indices(dists, pos_indices)

# Same for the negatives, but we need to turn the logits around,
# since we want to sample the smaller distances more likely.
neg_logits = tf.where(negative_mask, -dists, neg_infs)
neg_indices = tf.multinomial(neg_logits, num_samples=1)[:,0]
negative = get_at_indices(dists, neg_indices)
elif variant == 'hard':
# Furthest one is worst positive.
positive = tf.reduce_max(dists*tf.cast(positive_mask, tf.float32), axis=1)
# Closest one is worst negative.
negative = tf.map_fn(lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])),
(dists, negative_mask), tf.float32)
# negative = tf.reduce_min(dists + 1e5*tf.cast(same_identity_mask, tf.float32), axis=1)

losses = apply_margin(positive - negative, margin)

return return_with_extra_stats(losses, dists, batch_precision_at_k,
same_identity_mask,
positive_mask, negative_mask)

def batch_hard(dists, pids, margin, batch_precision_at_k=None):
return _generic_batchloss(dists, pids, margin, batch_precision_at_k, variant='hard')


def batch_sample(dists, pids, margin, batch_precision_at_k=None):
return _generic_batchloss(dists, pids, margin, batch_precision_at_k, variant='sample')
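
Both wrappers above route through _generic_batchloss, which applies apply_margin to positive − negative per anchor. A hedged summary of how the two variants pick those terms, with D_ap and D_an the anchor–positive and anchor–negative distances (tf.multinomial interprets its inputs as unnormalized log-probabilities):

$$\text{hard: } \max_{p} D_{ap} - \min_{n} D_{an}, \qquad \text{sample: } D_{ap} - D_{an} \ \text{with} \ P(p \mid a) \propto e^{D_{ap}}, \ P(n \mid a) \propto e^{-D_{an}}.$$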


def batch_all(dists, pids, margin, batch_precision_at_k=None):
with tf.name_scope("batch_all"):
same_identity_mask = tf.equal(tf.expand_dims(pids, axis=1),
tf.expand_dims(pids, axis=0))
negative_mask = tf.logical_not(same_identity_mask)
positive_mask = tf.logical_xor(same_identity_mask,
tf.eye(tf.shape(pids)[0], dtype=tf.bool))

# Unfortunately, foldl can only go over one tensor, unlike map_fn,
# so we need to convert and stack around.
packed = tf.stack([dists,
tf.cast(positive_mask, tf.float32),
tf.cast(negative_mask, tf.float32)], axis=1)

def per_anchor(accum, row):
# `dists_` is a 1D array of distances (one row of `dists`).
# `poss_` is a 1D 0/1 float mask marking positives (cast back to bool below).
# `negs_` is a 1D 0/1 float mask marking negatives (cast back to bool below).
dists_, poss_, negs_ = row[0], row[1], row[2]

# Now construct a (P,N)-matrix of all-to-all (anchor-pos - anchor-neg).
diff = all_diffs(tf.boolean_mask(dists_, tf.cast(poss_, tf.bool)),
tf.boolean_mask(dists_, tf.cast(negs_, tf.bool)))

losses = tf.reshape(apply_margin(diff, margin), [-1])
return tf.concat([accum, losses], axis=0)

# Some very advanced trickery in order to get the initialization tensor
# to be an empty 1D tensor with a dynamic shape, such that it is
# allowed to grow during the iteration.
init = tf.placeholder_with_default([], shape=[None])
losses = tf.foldl(per_anchor, packed, init)

return return_with_extra_stats(losses, dists, batch_precision_at_k,
same_identity_mask,
positive_mask, negative_mask)
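
The foldl initialization above relies on the accumulator having an unknown static shape so that it may grow. A minimal, self-contained TF1-style sketch of that trick, using toy tensors that are not part of the diff:

```python
import tensorflow as tf

# The accumulator starts empty, but its *static* shape is [None], so tf.foldl's
# shape check allows it to change length on every iteration.
rows = tf.constant([[1., 2.], [3., 4.], [5., 6.]])

def append_row(accum, row):
    return tf.concat([accum, row], axis=0)  # grows the accumulator by len(row)

init = tf.placeholder_with_default([], shape=[None])
flat = tf.foldl(append_row, rows, initializer=init)

with tf.Session() as sess:
    print(sess.run(flat))  # [1. 2. 3. 4. 5. 6.]
```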


def return_with_extra_stats(to_return, dists, batch_precision_at_k,
same_identity_mask, positive_mask, negative_mask):
if batch_precision_at_k is None:
return diff
return to_return

# For monitoring, compute the within-batch top-1 accuracy and the
# within-batch precision-at-k, which is somewhat more expressive.
@@ -142,9 +212,12 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
negative_dists = tf.boolean_mask(dists, negative_mask)
positive_dists = tf.boolean_mask(dists, positive_mask)

return diff, top1, prec_at_k, topk_is_same, negative_dists, positive_dists
return to_return, top1, prec_at_k, topk_is_same, negative_dists, positive_dists



LOSS_CHOICES = {
'batch_hard': batch_hard,
'batch_sample': batch_sample,
'batch_all': batch_all,
}
31 changes: 23 additions & 8 deletions train.py
@@ -102,9 +102,13 @@
help='Which metric to use for the distance between embeddings.')

parser.add_argument(
'--loss', default='batch_hard', choices=loss.LOSS_CHOICES.keys(),
'--loss', default='batch_hard', choices=loss.LOSS_CHOICES,
help='Enable the super-mega-advanced top-secret sampling stabilizer.')

parser.add_argument(
'--loss_ignore_zero', default=False, const=True, nargs='?', type=common.positive_float,
maxisme (Contributor) commented on Apr 25, 2018:

Ohh misread this to mean it can only be boolean. I am going to start playing with this then. 🍾

help='Average only over non-zero loss values, called "=/=0" in the paper.')
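
To make the type-magic of this flag concrete (it is also discussed in the review thread further down): with nargs='?', argparse uses the default when the flag is absent, the const when it is given bare, and the type-converted value otherwise. A self-contained sketch, where positive_float is a simplified stand-in for the repo's common.positive_float:

```python
import argparse

def positive_float(x):  # stand-in for common.positive_float
    value = float(x)
    if value <= 0:
        raise argparse.ArgumentTypeError('expected a positive float')
    return value

parser = argparse.ArgumentParser()
parser.add_argument('--loss_ignore_zero', default=False, const=True,
                    nargs='?', type=positive_float)

print(parser.parse_args([]).loss_ignore_zero)                             # False: plain mean over all losses
print(parser.parse_args(['--loss_ignore_zero']).loss_ignore_zero)         # True: average exact non-zeros only
print(parser.parse_args(['--loss_ignore_zero', '1e-3']).loss_ignore_zero) # 0.001: treat losses below 1e-3 as zero
```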

parser.add_argument(
'--learning_rate', default=3e-4, type=common.positive_float,
help='The initial value of the learning-rate, before it kicks in.')
@@ -141,6 +145,11 @@
' embeddings, losses and FIDs seen in each batch during training.'
' Everything can be re-constructed and analyzed that way.')

parser.add_argument(
'--optim', default='AdamOptimizer(learning_rate)',
help='Which optimizer to use. This is actual TensorFlow code that will be'
' eval\'d. Use `learning_rate` for the learning-rate with schedule.')
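
A hypothetical invocation combining the new flags (any optimizer constructor available under tf.train should work here, e.g. MomentumOptimizer): python train.py --loss batch_sample --loss_ignore_zero 1e-3 --optim 'MomentumOptimizer(learning_rate, 0.9)', together with the script's usual dataset and experiment arguments.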


def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
""" Given a PID, select K FIDs of that specific PID. """
@@ -294,16 +303,24 @@ def main():
losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)

# Count the number of active entries, and compute the total batch loss.
num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
loss_mean = tf.reduce_mean(losses)
# Count how many entries in the batch are (possibly approximately) non-zero.
if args.loss_ignore_zero is True:
nnz = tf.count_nonzero(losses, dtype=tf.float32)
else:
nnz = tf.reduce_sum(tf.to_float(tf.greater(losses, args.loss_ignore_zero or 1e-5)))
Contributor commented:

Is the point of this supposed to be just for logging?

Member Author commented:

No, actually. It's type-magic and can be a little obscure, hence why I still need to write documentation in the README :)

The else case happens when loss_ignore_zero is given an additional float argument, so one can call it as --loss_ignore_zero 1e-3 for example, in order to consider anything below 1e-3 to be counted as zero.

Member Author commented:

Read our paper, we explain them in there :) But really it's not a good time investment to play with that parameter.

maxisme (Contributor) commented on Apr 25, 2018:

I am going to delete that comment because it makes no sense sorry. 😆 Currently have it printed and highlighted in front of me trying to get to grips!


# Compute the total batch-loss by either averaging all, or averaging non-zeros only.
if args.loss_ignore_zero is False:
loss_mean = tf.reduce_mean(losses)
else:
loss_mean = tf.reduce_sum(losses) / (1e-33 + nnz)

# Some logging for tensorboard.
tf.summary.histogram('loss_distribution', losses)
tf.summary.scalar('loss', loss_mean)
tf.summary.scalar('batch_top1', train_top1)
tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k-1), prec_at_k)
tf.summary.scalar('active_count', num_active)
tf.summary.scalar('active_count', nnz)
tf.summary.histogram('embedding_dists', dists)
tf.summary.histogram('embedding_pos_dists', pos_dists)
tf.summary.histogram('embedding_neg_dists', neg_dists)
@@ -341,9 +358,7 @@ def main():
else:
learning_rate = args.learning_rate
tf.summary.scalar('learning_rate', learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate)
# Feel free to try others!
# optimizer = tf.train.AdadeltaOptimizer(learning_rate)
optimizer = eval("tf.train." + args.optim)

# Update_ops are used to update batchnorm stats.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):