Skip to content

Commit

Permalink
FIX the threshold by taking the opposite (to be adapted to the decisi…
Browse files Browse the repository at this point in the history
…on function)
  • Loading branch information
William de Vazelhes committed Feb 20, 2019
1 parent dc9e21d commit 402729f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def predict(self, pairs):
The predicted learned metric value between samples in every pair.
"""
check_is_fitted(self, ['threshold_', 'transformer_'])
return - 2 * (self.decision_function(pairs) > self.threshold_) + 1
return 2 * (self.decision_function(pairs) > self.threshold_) - 1

def decision_function(self, pairs):
"""Returns the decision function used to classify the pairs.
Expand Down Expand Up @@ -387,13 +387,13 @@ def score(self, pairs, y):
return roc_auc_score(y, self.decision_function(pairs))

def set_default_threshold(self, pairs, y):
"""Returns a threshold that is the mean between the similar metrics
mean, and the dissimilar metrics mean"""
similar_threshold = np.mean(self.decision_function(
"""Returns a threshold that is the opposite of the mean between the similar
metrics mean and the dissimilar metrics mean"""
similar_threshold = np.mean(self.score_pairs(
pairs[(y == 1).ravel()]))
dissimilar_threshold = np.mean(self.decision_function(
dissimilar_threshold = np.mean(self.score_pairs(
pairs[(y == -1).ravel()]))
self.threshold_ = np.mean([similar_threshold, dissimilar_threshold])
self.threshold_ = - np.mean([similar_threshold, dissimilar_threshold])


class _QuadrupletsClassifierMixin(BaseMetricLearner):
Expand Down
2 changes: 1 addition & 1 deletion metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def fit(self, pairs, y, bounds=None):
Returns the instance.
"""
self._fit(pairs, y, bounds=bounds)
self.threshold_ = np.mean(self.bounds_)
self.threshold_ = - np.mean(self.bounds_)
return self


Expand Down

0 comments on commit 402729f

Please sign in to comment.